From 14ed1096929e34d8ba51931d4661e3860037051d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Konrad=20Gr=C3=A4fe?= Date: Fri, 17 Mar 2023 13:13:31 +0100 Subject: [PATCH 1/2] tcp: Add keepalive option MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The reason the actual work is done when pending() returns PendingStatus::Ready is that on Windows the option cannot be set during connect (when connect() was called but the socket is not yet connected). Also this way it applies to both outgoing and incoming connections. Signed-off-by: Konrad Gräfe --- src/adapters/tcp.rs | 76 +++++++++++++++++++++++++++++++++------- src/network/transport.rs | 16 +++++---- 2 files changed, 72 insertions(+), 20 deletions(-) diff --git a/src/adapters/tcp.rs b/src/adapters/tcp.rs index d8a1062..a781d0d 100644 --- a/src/adapters/tcp.rs +++ b/src/adapters/tcp.rs @@ -7,16 +7,34 @@ use crate::network::{RemoteAddr, Readiness, TransportConnect, TransportListen}; use mio::net::{TcpListener, TcpStream}; use mio::event::{Source}; +use socket2::{Socket, TcpKeepalive}; + use std::net::{SocketAddr}; use std::io::{self, ErrorKind, Read, Write}; use std::ops::{Deref}; -use std::mem::{MaybeUninit}; +use std::mem::{forget, MaybeUninit}; +#[cfg(target_os = "windows")] +use std::os::windows::io::{FromRawSocket, AsRawSocket}; +#[cfg(not(target_os = "windows"))] +use std::os::{fd::AsRawFd, unix::io::FromRawFd}; /// Size of the internal reading buffer. /// It implies that at most the generated [`crate::network::NetEvent::Message`] /// will contains a chunk of data of this value. pub const INPUT_BUFFER_SIZE: usize = u16::MAX as usize; // 2^16 - 1 +#[derive(Clone, Debug, Default)] +pub struct TcpConnectConfig { + /// Enables TCP keepalive settings on the socket. + pub keepalive: Option, +} + +#[derive(Clone, Debug, Default)] +pub struct TcpListenConfig { + /// Enables TCP keepalive settings on client connection sockets. + pub keepalive: Option, +} + pub(crate) struct TcpAdapter; impl Adapter for TcpAdapter { type Remote = RemoteResource; @@ -25,12 +43,7 @@ impl Adapter for TcpAdapter { pub(crate) struct RemoteResource { stream: TcpStream, -} - -impl From for RemoteResource { - fn from(stream: TcpStream) -> Self { - Self { stream } - } + keepalive: Option, } impl Resource for RemoteResource { @@ -41,13 +54,21 @@ impl Resource for RemoteResource { impl Remote for RemoteResource { fn connect_with( - _: TransportConnect, + config: TransportConnect, remote_addr: RemoteAddr, ) -> io::Result> { + let config = match config { + TransportConnect::Tcp(config) => config, + _ => panic!("Internal error: Got wrong config"), + }; let peer_addr = *remote_addr.socket_addr(); let stream = TcpStream::connect(peer_addr)?; let local_addr = stream.local_addr()?; - Ok(ConnectionInfo { remote: stream.into(), local_addr, peer_addr }) + Ok(ConnectionInfo { + remote: Self { stream, keepalive: config.keepalive }, + local_addr, + peer_addr, + }) } fn receive(&self, mut process_data: impl FnMut(&[u8])) -> ReadStatus { @@ -102,7 +123,25 @@ impl Remote for RemoteResource { } fn pending(&self, _readiness: Readiness) -> PendingStatus { - check_stream_ready(&self.stream) + let status = check_stream_ready(&self.stream); + + if status == PendingStatus::Ready { + if let Some(keepalive) = &self.keepalive { + #[cfg(target_os = "windows")] + let socket = unsafe { Socket::from_raw_socket(self.stream.as_raw_socket()) }; + #[cfg(not(target_os = "windows"))] + let socket = unsafe { Socket::from_raw_fd(self.stream.as_raw_fd()) }; + + if let Err(e) = socket.set_tcp_keepalive(keepalive) { + log::warn!("TCP set keepalive error: {}", e); + } + + // Don't drop so the underlying socket is not closed. + forget(socket); + } + } + + status } } @@ -123,6 +162,7 @@ pub fn check_stream_ready(stream: &TcpStream) -> PendingStatus { pub(crate) struct LocalResource { listener: TcpListener, + keepalive: Option, } impl Resource for LocalResource { @@ -134,16 +174,26 @@ impl Resource for LocalResource { impl Local for LocalResource { type Remote = RemoteResource; - fn listen_with(_: TransportListen, addr: SocketAddr) -> io::Result> { + fn listen_with(config: TransportListen, addr: SocketAddr) -> io::Result> { + let config = match config { + TransportListen::Tcp(config) => config, + _ => panic!("Internal error: Got wrong config"), + }; let listener = TcpListener::bind(addr)?; let local_addr = listener.local_addr().unwrap(); - Ok(ListeningInfo { local: { LocalResource { listener } }, local_addr }) + Ok(ListeningInfo { + local: { LocalResource { listener, keepalive: config.keepalive } }, + local_addr, + }) } fn accept(&self, mut accept_remote: impl FnMut(AcceptedType<'_, Self::Remote>)) { loop { match self.listener.accept() { - Ok((stream, addr)) => accept_remote(AcceptedType::Remote(addr, stream.into())), + Ok((stream, addr)) => accept_remote(AcceptedType::Remote( + addr, + RemoteResource { stream, keepalive: self.keepalive.clone() }, + )), Err(ref err) if err.kind() == ErrorKind::WouldBlock => break, Err(ref err) if err.kind() == ErrorKind::Interrupted => continue, Err(err) => break log::error!("TCP accept error: {}", err), // Should not happen diff --git a/src/network/transport.rs b/src/network/transport.rs index c804e96..8736d22 100644 --- a/src/network/transport.rs +++ b/src/network/transport.rs @@ -1,7 +1,7 @@ use super::loader::{DriverLoader}; #[cfg(feature = "tcp")] -use crate::adapters::tcp::{TcpAdapter}; +use crate::adapters::tcp::{TcpAdapter, TcpConnectConfig, TcpListenConfig}; #[cfg(feature = "tcp")] use crate::adapters::framed_tcp::{FramedTcpAdapter}; #[cfg(feature = "udp")] @@ -157,9 +157,10 @@ impl std::fmt::Display for Transport { } } +#[derive(Debug)] pub enum TransportConnect { #[cfg(feature = "tcp")] - Tcp, + Tcp(TcpConnectConfig), #[cfg(feature = "tcp")] FramedTcp, #[cfg(feature = "udp")] @@ -172,7 +173,7 @@ impl TransportConnect { pub fn id(&self) -> u8 { let transport = match self { #[cfg(feature = "tcp")] - Self::Tcp => Transport::Tcp, + Self::Tcp(_) => Transport::Tcp, #[cfg(feature = "tcp")] Self::FramedTcp => Transport::FramedTcp, #[cfg(feature = "udp")] @@ -189,7 +190,7 @@ impl From for TransportConnect { fn from(transport: Transport) -> Self { match transport { #[cfg(feature = "tcp")] - Transport::Tcp => Self::Tcp, + Transport::Tcp => Self::Tcp(TcpConnectConfig::default()), #[cfg(feature = "tcp")] Transport::FramedTcp => Self::FramedTcp, #[cfg(feature = "udp")] @@ -200,9 +201,10 @@ impl From for TransportConnect { } } +#[derive(Debug)] pub enum TransportListen { #[cfg(feature = "tcp")] - Tcp, + Tcp(TcpListenConfig), #[cfg(feature = "tcp")] FramedTcp, #[cfg(feature = "udp")] @@ -215,7 +217,7 @@ impl TransportListen { pub fn id(&self) -> u8 { let transport = match self { #[cfg(feature = "tcp")] - Self::Tcp => Transport::Tcp, + Self::Tcp(_) => Transport::Tcp, #[cfg(feature = "tcp")] Self::FramedTcp => Transport::FramedTcp, #[cfg(feature = "udp")] @@ -232,7 +234,7 @@ impl From for TransportListen { fn from(transport: Transport) -> Self { match transport { #[cfg(feature = "tcp")] - Transport::Tcp => Self::Tcp, + Transport::Tcp => Self::Tcp(TcpListenConfig::default()), #[cfg(feature = "tcp")] Transport::FramedTcp => Self::FramedTcp, #[cfg(feature = "udp")] From 2c720711bc07eb8c5d3a273c09004b467d5b7c28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Konrad=20Gr=C3=A4fe?= Date: Fri, 17 Mar 2023 13:32:47 +0100 Subject: [PATCH 2/2] framed_tcp: Add TCP keepalive option MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Konrad Gräfe --- src/adapters/framed_tcp.rs | 76 +++++++++++++++++++++++++++++++++----- src/network/transport.rs | 14 +++---- 2 files changed, 73 insertions(+), 17 deletions(-) diff --git a/src/adapters/framed_tcp.rs b/src/adapters/framed_tcp.rs index c5003ba..fff93b3 100644 --- a/src/adapters/framed_tcp.rs +++ b/src/adapters/framed_tcp.rs @@ -8,14 +8,32 @@ use crate::util::encoding::{self, Decoder, MAX_ENCODED_SIZE}; use mio::net::{TcpListener, TcpStream}; use mio::event::{Source}; +use socket2::{Socket, TcpKeepalive}; + use std::net::{SocketAddr}; use std::io::{self, ErrorKind, Read, Write}; use std::ops::{Deref}; use std::cell::{RefCell}; -use std::mem::{MaybeUninit}; +use std::mem::{forget, MaybeUninit}; +#[cfg(target_os = "windows")] +use std::os::windows::io::{FromRawSocket, AsRawSocket}; +#[cfg(not(target_os = "windows"))] +use std::os::{fd::AsRawFd, unix::io::FromRawFd}; const INPUT_BUFFER_SIZE: usize = u16::MAX as usize; // 2^16 - 1 +#[derive(Clone, Debug, Default)] +pub struct FramedTcpConnectConfig { + /// Enables TCP keepalive settings on the socket. + pub keepalive: Option, +} + +#[derive(Clone, Debug, Default)] +pub struct FramedTcpListenConfig { + /// Enables TCP keepalive settings on client connection sockets. + pub keepalive: Option, +} + pub(crate) struct FramedTcpAdapter; impl Adapter for FramedTcpAdapter { type Remote = RemoteResource; @@ -25,6 +43,7 @@ impl Adapter for FramedTcpAdapter { pub(crate) struct RemoteResource { stream: TcpStream, decoder: RefCell, + keepalive: Option, } // SAFETY: @@ -32,9 +51,9 @@ pub(crate) struct RemoteResource { // that will be called always from the same thread. This way, we save the cost of a Mutex. unsafe impl Sync for RemoteResource {} -impl From for RemoteResource { - fn from(stream: TcpStream) -> Self { - Self { stream, decoder: RefCell::new(Decoder::default()) } +impl RemoteResource { + fn new(stream: TcpStream, keepalive: Option) -> Self { + Self { stream, decoder: RefCell::new(Decoder::default()), keepalive } } } @@ -46,13 +65,21 @@ impl Resource for RemoteResource { impl Remote for RemoteResource { fn connect_with( - _: TransportConnect, + config: TransportConnect, remote_addr: RemoteAddr, ) -> io::Result> { + let config = match config { + TransportConnect::FramedTcp(config) => config, + _ => panic!("Internal error: Got wrong config"), + }; let peer_addr = *remote_addr.socket_addr(); let stream = TcpStream::connect(peer_addr)?; let local_addr = stream.local_addr()?; - Ok(ConnectionInfo { remote: stream.into(), local_addr, peer_addr }) + Ok(ConnectionInfo { + remote: RemoteResource::new(stream, config.keepalive), + local_addr, + peer_addr, + }) } fn receive(&self, mut process_data: impl FnMut(&[u8])) -> ReadStatus { @@ -115,12 +142,31 @@ impl Remote for RemoteResource { } fn pending(&self, _readiness: Readiness) -> PendingStatus { - super::tcp::check_stream_ready(&self.stream) + let status = super::tcp::check_stream_ready(&self.stream); + + if status == PendingStatus::Ready { + if let Some(keepalive) = &self.keepalive { + #[cfg(target_os = "windows")] + let socket = unsafe { Socket::from_raw_socket(self.stream.as_raw_socket()) }; + #[cfg(not(target_os = "windows"))] + let socket = unsafe { Socket::from_raw_fd(self.stream.as_raw_fd()) }; + + if let Err(e) = socket.set_tcp_keepalive(keepalive) { + log::warn!("TCP set keepalive error: {}", e); + } + + // Don't drop so the underlying socket is not closed. + forget(socket); + } + } + + status } } pub(crate) struct LocalResource { listener: TcpListener, + keepalive: Option, } impl Resource for LocalResource { @@ -132,16 +178,26 @@ impl Resource for LocalResource { impl Local for LocalResource { type Remote = RemoteResource; - fn listen_with(_: TransportListen, addr: SocketAddr) -> io::Result> { + fn listen_with(config: TransportListen, addr: SocketAddr) -> io::Result> { + let config = match config { + TransportListen::FramedTcp(config) => config, + _ => panic!("Internal error: Got wrong config"), + }; let listener = TcpListener::bind(addr)?; let local_addr = listener.local_addr().unwrap(); - Ok(ListeningInfo { local: { LocalResource { listener } }, local_addr }) + Ok(ListeningInfo { + local: { LocalResource { listener, keepalive: config.keepalive } }, + local_addr, + }) } fn accept(&self, mut accept_remote: impl FnMut(AcceptedType<'_, Self::Remote>)) { loop { match self.listener.accept() { - Ok((stream, addr)) => accept_remote(AcceptedType::Remote(addr, stream.into())), + Ok((stream, addr)) => accept_remote(AcceptedType::Remote( + addr, + RemoteResource::new(stream, self.keepalive.clone()), + )), Err(ref err) if err.kind() == ErrorKind::WouldBlock => break, Err(ref err) if err.kind() == ErrorKind::Interrupted => continue, Err(err) => break log::error!("TCP accept error: {}", err), // Should not happen diff --git a/src/network/transport.rs b/src/network/transport.rs index 8736d22..b4d424a 100644 --- a/src/network/transport.rs +++ b/src/network/transport.rs @@ -3,7 +3,7 @@ use super::loader::{DriverLoader}; #[cfg(feature = "tcp")] use crate::adapters::tcp::{TcpAdapter, TcpConnectConfig, TcpListenConfig}; #[cfg(feature = "tcp")] -use crate::adapters::framed_tcp::{FramedTcpAdapter}; +use crate::adapters::framed_tcp::{FramedTcpAdapter, FramedTcpConnectConfig, FramedTcpListenConfig}; #[cfg(feature = "udp")] use crate::adapters::udp::{self, UdpAdapter, UdpConnectConfig, UdpListenConfig}; #[cfg(feature = "websocket")] @@ -162,7 +162,7 @@ pub enum TransportConnect { #[cfg(feature = "tcp")] Tcp(TcpConnectConfig), #[cfg(feature = "tcp")] - FramedTcp, + FramedTcp(FramedTcpConnectConfig), #[cfg(feature = "udp")] Udp(UdpConnectConfig), #[cfg(feature = "websocket")] @@ -175,7 +175,7 @@ impl TransportConnect { #[cfg(feature = "tcp")] Self::Tcp(_) => Transport::Tcp, #[cfg(feature = "tcp")] - Self::FramedTcp => Transport::FramedTcp, + Self::FramedTcp(_) => Transport::FramedTcp, #[cfg(feature = "udp")] Self::Udp(_) => Transport::Udp, #[cfg(feature = "websocket")] @@ -192,7 +192,7 @@ impl From for TransportConnect { #[cfg(feature = "tcp")] Transport::Tcp => Self::Tcp(TcpConnectConfig::default()), #[cfg(feature = "tcp")] - Transport::FramedTcp => Self::FramedTcp, + Transport::FramedTcp => Self::FramedTcp(FramedTcpConnectConfig::default()), #[cfg(feature = "udp")] Transport::Udp => Self::Udp(UdpConnectConfig::default()), #[cfg(feature = "websocket")] @@ -206,7 +206,7 @@ pub enum TransportListen { #[cfg(feature = "tcp")] Tcp(TcpListenConfig), #[cfg(feature = "tcp")] - FramedTcp, + FramedTcp(FramedTcpListenConfig), #[cfg(feature = "udp")] Udp(UdpListenConfig), #[cfg(feature = "websocket")] @@ -219,7 +219,7 @@ impl TransportListen { #[cfg(feature = "tcp")] Self::Tcp(_) => Transport::Tcp, #[cfg(feature = "tcp")] - Self::FramedTcp => Transport::FramedTcp, + Self::FramedTcp(_) => Transport::FramedTcp, #[cfg(feature = "udp")] Self::Udp(_) => Transport::Udp, #[cfg(feature = "websocket")] @@ -236,7 +236,7 @@ impl From for TransportListen { #[cfg(feature = "tcp")] Transport::Tcp => Self::Tcp(TcpListenConfig::default()), #[cfg(feature = "tcp")] - Transport::FramedTcp => Self::FramedTcp, + Transport::FramedTcp => Self::FramedTcp(FramedTcpListenConfig::default()), #[cfg(feature = "udp")] Transport::Udp => Self::Udp(UdpListenConfig::default()), #[cfg(feature = "websocket")]