Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TCP keepalive option #143

Merged
merged 2 commits into from
Mar 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 66 additions & 10 deletions src/adapters/framed_tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,32 @@ use crate::util::encoding::{self, Decoder, MAX_ENCODED_SIZE};
use mio::net::{TcpListener, TcpStream};
use mio::event::{Source};

use socket2::{Socket, TcpKeepalive};

use std::net::{SocketAddr};
use std::io::{self, ErrorKind, Read, Write};
use std::ops::{Deref};
use std::cell::{RefCell};
use std::mem::{MaybeUninit};
use std::mem::{forget, MaybeUninit};
#[cfg(target_os = "windows")]
use std::os::windows::io::{FromRawSocket, AsRawSocket};
#[cfg(not(target_os = "windows"))]
use std::os::{fd::AsRawFd, unix::io::FromRawFd};

const INPUT_BUFFER_SIZE: usize = u16::MAX as usize; // 2^16 - 1

#[derive(Clone, Debug, Default)]
pub struct FramedTcpConnectConfig {
/// Enables TCP keepalive settings on the socket.
pub keepalive: Option<TcpKeepalive>,
}

#[derive(Clone, Debug, Default)]
pub struct FramedTcpListenConfig {
/// Enables TCP keepalive settings on client connection sockets.
pub keepalive: Option<TcpKeepalive>,
}

pub(crate) struct FramedTcpAdapter;
impl Adapter for FramedTcpAdapter {
type Remote = RemoteResource;
Expand All @@ -25,16 +43,17 @@ impl Adapter for FramedTcpAdapter {
pub(crate) struct RemoteResource {
stream: TcpStream,
decoder: RefCell<Decoder>,
keepalive: Option<TcpKeepalive>,
}

// SAFETY:
// That RefCell<Decoder> can be used with Sync because the decoder is only used in the read_event,
// that will be called always from the same thread. This way, we save the cost of a Mutex.
unsafe impl Sync for RemoteResource {}

impl From<TcpStream> for RemoteResource {
fn from(stream: TcpStream) -> Self {
Self { stream, decoder: RefCell::new(Decoder::default()) }
impl RemoteResource {
fn new(stream: TcpStream, keepalive: Option<TcpKeepalive>) -> Self {
Self { stream, decoder: RefCell::new(Decoder::default()), keepalive }
}
}

Expand All @@ -46,13 +65,21 @@ impl Resource for RemoteResource {

impl Remote for RemoteResource {
fn connect_with(
_: TransportConnect,
config: TransportConnect,
remote_addr: RemoteAddr,
) -> io::Result<ConnectionInfo<Self>> {
let config = match config {
TransportConnect::FramedTcp(config) => config,
_ => panic!("Internal error: Got wrong config"),
};
let peer_addr = *remote_addr.socket_addr();
let stream = TcpStream::connect(peer_addr)?;
let local_addr = stream.local_addr()?;
Ok(ConnectionInfo { remote: stream.into(), local_addr, peer_addr })
Ok(ConnectionInfo {
remote: RemoteResource::new(stream, config.keepalive),
local_addr,
peer_addr,
})
}

fn receive(&self, mut process_data: impl FnMut(&[u8])) -> ReadStatus {
Expand Down Expand Up @@ -115,12 +142,31 @@ impl Remote for RemoteResource {
}

fn pending(&self, _readiness: Readiness) -> PendingStatus {
super::tcp::check_stream_ready(&self.stream)
let status = super::tcp::check_stream_ready(&self.stream);

if status == PendingStatus::Ready {
if let Some(keepalive) = &self.keepalive {
#[cfg(target_os = "windows")]
let socket = unsafe { Socket::from_raw_socket(self.stream.as_raw_socket()) };
#[cfg(not(target_os = "windows"))]
let socket = unsafe { Socket::from_raw_fd(self.stream.as_raw_fd()) };

if let Err(e) = socket.set_tcp_keepalive(keepalive) {
log::warn!("TCP set keepalive error: {}", e);
}

// Don't drop so the underlying socket is not closed.
forget(socket);
}
}

status
}
}

pub(crate) struct LocalResource {
listener: TcpListener,
keepalive: Option<TcpKeepalive>,
}

impl Resource for LocalResource {
Expand All @@ -132,16 +178,26 @@ impl Resource for LocalResource {
impl Local for LocalResource {
type Remote = RemoteResource;

fn listen_with(_: TransportListen, addr: SocketAddr) -> io::Result<ListeningInfo<Self>> {
fn listen_with(config: TransportListen, addr: SocketAddr) -> io::Result<ListeningInfo<Self>> {
let config = match config {
TransportListen::FramedTcp(config) => config,
_ => panic!("Internal error: Got wrong config"),
};
let listener = TcpListener::bind(addr)?;
let local_addr = listener.local_addr().unwrap();
Ok(ListeningInfo { local: { LocalResource { listener } }, local_addr })
Ok(ListeningInfo {
local: { LocalResource { listener, keepalive: config.keepalive } },
local_addr,
})
}

fn accept(&self, mut accept_remote: impl FnMut(AcceptedType<'_, Self::Remote>)) {
loop {
match self.listener.accept() {
Ok((stream, addr)) => accept_remote(AcceptedType::Remote(addr, stream.into())),
Ok((stream, addr)) => accept_remote(AcceptedType::Remote(
addr,
RemoteResource::new(stream, self.keepalive.clone()),
)),
Err(ref err) if err.kind() == ErrorKind::WouldBlock => break,
Err(ref err) if err.kind() == ErrorKind::Interrupted => continue,
Err(err) => break log::error!("TCP accept error: {}", err), // Should not happen
Expand Down
76 changes: 63 additions & 13 deletions src/adapters/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,34 @@ use crate::network::{RemoteAddr, Readiness, TransportConnect, TransportListen};
use mio::net::{TcpListener, TcpStream};
use mio::event::{Source};

use socket2::{Socket, TcpKeepalive};

use std::net::{SocketAddr};
use std::io::{self, ErrorKind, Read, Write};
use std::ops::{Deref};
use std::mem::{MaybeUninit};
use std::mem::{forget, MaybeUninit};
#[cfg(target_os = "windows")]
use std::os::windows::io::{FromRawSocket, AsRawSocket};
#[cfg(not(target_os = "windows"))]
use std::os::{fd::AsRawFd, unix::io::FromRawFd};

/// Size of the internal reading buffer.
/// It implies that at most the generated [`crate::network::NetEvent::Message`]
/// will contains a chunk of data of this value.
pub const INPUT_BUFFER_SIZE: usize = u16::MAX as usize; // 2^16 - 1

#[derive(Clone, Debug, Default)]
pub struct TcpConnectConfig {
/// Enables TCP keepalive settings on the socket.
pub keepalive: Option<TcpKeepalive>,
}

#[derive(Clone, Debug, Default)]
pub struct TcpListenConfig {
/// Enables TCP keepalive settings on client connection sockets.
pub keepalive: Option<TcpKeepalive>,
}

pub(crate) struct TcpAdapter;
impl Adapter for TcpAdapter {
type Remote = RemoteResource;
Expand All @@ -25,12 +43,7 @@ impl Adapter for TcpAdapter {

pub(crate) struct RemoteResource {
stream: TcpStream,
}

impl From<TcpStream> for RemoteResource {
fn from(stream: TcpStream) -> Self {
Self { stream }
}
keepalive: Option<TcpKeepalive>,
}

impl Resource for RemoteResource {
Expand All @@ -41,13 +54,21 @@ impl Resource for RemoteResource {

impl Remote for RemoteResource {
fn connect_with(
_: TransportConnect,
config: TransportConnect,
remote_addr: RemoteAddr,
) -> io::Result<ConnectionInfo<Self>> {
let config = match config {
TransportConnect::Tcp(config) => config,
_ => panic!("Internal error: Got wrong config"),
};
let peer_addr = *remote_addr.socket_addr();
let stream = TcpStream::connect(peer_addr)?;
let local_addr = stream.local_addr()?;
Ok(ConnectionInfo { remote: stream.into(), local_addr, peer_addr })
Ok(ConnectionInfo {
remote: Self { stream, keepalive: config.keepalive },
local_addr,
peer_addr,
})
}

fn receive(&self, mut process_data: impl FnMut(&[u8])) -> ReadStatus {
Expand Down Expand Up @@ -102,7 +123,25 @@ impl Remote for RemoteResource {
}

fn pending(&self, _readiness: Readiness) -> PendingStatus {
check_stream_ready(&self.stream)
let status = check_stream_ready(&self.stream);

if status == PendingStatus::Ready {
if let Some(keepalive) = &self.keepalive {
#[cfg(target_os = "windows")]
let socket = unsafe { Socket::from_raw_socket(self.stream.as_raw_socket()) };
#[cfg(not(target_os = "windows"))]
let socket = unsafe { Socket::from_raw_fd(self.stream.as_raw_fd()) };

if let Err(e) = socket.set_tcp_keepalive(keepalive) {
log::warn!("TCP set keepalive error: {}", e);
}

// Don't drop so the underlying socket is not closed.
forget(socket);
}
}

status
Comment on lines +128 to +144
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering why do this here instead of in connect()/listen() methods.

Could we get rid of the unsafe code building the socket with socket2, and once it's built, get the stream? instead of building the stream and getting the socket from it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We cannot do that in listen() because it concerns the accepted socket, not the listener socket. We could do it in accept() but mio does quite a bit of platform dependent stuff in the UNIX path which contains unsafe as well. We would loose that or have to copy and maintain it ourselves. I'd rather avoid that and live with the unsafe code.

(The Windows path just calls std::net::TcpListener::accept(), though.)

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, thanks for your thoughts!

}
}

Expand All @@ -123,6 +162,7 @@ pub fn check_stream_ready(stream: &TcpStream) -> PendingStatus {

pub(crate) struct LocalResource {
listener: TcpListener,
keepalive: Option<TcpKeepalive>,
}

impl Resource for LocalResource {
Expand All @@ -134,16 +174,26 @@ impl Resource for LocalResource {
impl Local for LocalResource {
type Remote = RemoteResource;

fn listen_with(_: TransportListen, addr: SocketAddr) -> io::Result<ListeningInfo<Self>> {
fn listen_with(config: TransportListen, addr: SocketAddr) -> io::Result<ListeningInfo<Self>> {
let config = match config {
TransportListen::Tcp(config) => config,
_ => panic!("Internal error: Got wrong config"),
};
let listener = TcpListener::bind(addr)?;
let local_addr = listener.local_addr().unwrap();
Ok(ListeningInfo { local: { LocalResource { listener } }, local_addr })
Ok(ListeningInfo {
local: { LocalResource { listener, keepalive: config.keepalive } },
local_addr,
})
}

fn accept(&self, mut accept_remote: impl FnMut(AcceptedType<'_, Self::Remote>)) {
loop {
match self.listener.accept() {
Ok((stream, addr)) => accept_remote(AcceptedType::Remote(addr, stream.into())),
Ok((stream, addr)) => accept_remote(AcceptedType::Remote(
addr,
RemoteResource { stream, keepalive: self.keepalive.clone() },
)),
Err(ref err) if err.kind() == ErrorKind::WouldBlock => break,
Err(ref err) if err.kind() == ErrorKind::Interrupted => continue,
Err(err) => break log::error!("TCP accept error: {}", err), // Should not happen
Expand Down
30 changes: 16 additions & 14 deletions src/network/transport.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use super::loader::{DriverLoader};

#[cfg(feature = "tcp")]
use crate::adapters::tcp::{TcpAdapter};
use crate::adapters::tcp::{TcpAdapter, TcpConnectConfig, TcpListenConfig};
#[cfg(feature = "tcp")]
use crate::adapters::framed_tcp::{FramedTcpAdapter};
use crate::adapters::framed_tcp::{FramedTcpAdapter, FramedTcpConnectConfig, FramedTcpListenConfig};
#[cfg(feature = "udp")]
use crate::adapters::udp::{self, UdpAdapter, UdpConnectConfig, UdpListenConfig};
#[cfg(feature = "websocket")]
Expand Down Expand Up @@ -157,11 +157,12 @@ impl std::fmt::Display for Transport {
}
}

#[derive(Debug)]
pub enum TransportConnect {
#[cfg(feature = "tcp")]
Tcp,
Tcp(TcpConnectConfig),
#[cfg(feature = "tcp")]
FramedTcp,
FramedTcp(FramedTcpConnectConfig),
#[cfg(feature = "udp")]
Udp(UdpConnectConfig),
#[cfg(feature = "websocket")]
Expand All @@ -172,9 +173,9 @@ impl TransportConnect {
pub fn id(&self) -> u8 {
let transport = match self {
#[cfg(feature = "tcp")]
Self::Tcp => Transport::Tcp,
Self::Tcp(_) => Transport::Tcp,
#[cfg(feature = "tcp")]
Self::FramedTcp => Transport::FramedTcp,
Self::FramedTcp(_) => Transport::FramedTcp,
#[cfg(feature = "udp")]
Self::Udp(_) => Transport::Udp,
#[cfg(feature = "websocket")]
Expand All @@ -189,9 +190,9 @@ impl From<Transport> for TransportConnect {
fn from(transport: Transport) -> Self {
match transport {
#[cfg(feature = "tcp")]
Transport::Tcp => Self::Tcp,
Transport::Tcp => Self::Tcp(TcpConnectConfig::default()),
#[cfg(feature = "tcp")]
Transport::FramedTcp => Self::FramedTcp,
Transport::FramedTcp => Self::FramedTcp(FramedTcpConnectConfig::default()),
#[cfg(feature = "udp")]
Transport::Udp => Self::Udp(UdpConnectConfig::default()),
#[cfg(feature = "websocket")]
Expand All @@ -200,11 +201,12 @@ impl From<Transport> for TransportConnect {
}
}

#[derive(Debug)]
pub enum TransportListen {
#[cfg(feature = "tcp")]
Tcp,
Tcp(TcpListenConfig),
#[cfg(feature = "tcp")]
FramedTcp,
FramedTcp(FramedTcpListenConfig),
#[cfg(feature = "udp")]
Udp(UdpListenConfig),
#[cfg(feature = "websocket")]
Expand All @@ -215,9 +217,9 @@ impl TransportListen {
pub fn id(&self) -> u8 {
let transport = match self {
#[cfg(feature = "tcp")]
Self::Tcp => Transport::Tcp,
Self::Tcp(_) => Transport::Tcp,
#[cfg(feature = "tcp")]
Self::FramedTcp => Transport::FramedTcp,
Self::FramedTcp(_) => Transport::FramedTcp,
#[cfg(feature = "udp")]
Self::Udp(_) => Transport::Udp,
#[cfg(feature = "websocket")]
Expand All @@ -232,9 +234,9 @@ impl From<Transport> for TransportListen {
fn from(transport: Transport) -> Self {
match transport {
#[cfg(feature = "tcp")]
Transport::Tcp => Self::Tcp,
Transport::Tcp => Self::Tcp(TcpListenConfig::default()),
#[cfg(feature = "tcp")]
Transport::FramedTcp => Self::FramedTcp,
Transport::FramedTcp => Self::FramedTcp(FramedTcpListenConfig::default()),
#[cfg(feature = "udp")]
Transport::Udp => Self::Udp(UdpListenConfig::default()),
#[cfg(feature = "websocket")]
Expand Down