diff --git a/src/web/api/mod.rs b/src/web/api/mod.rs index f2d5b996..0a4e108e 100644 --- a/src/web/api/mod.rs +++ b/src/web/api/mod.rs @@ -30,12 +30,6 @@ pub struct Running { pub api_server: Option>>, } -#[must_use] -#[derive(Debug)] -pub struct ServerStartedMessage { - pub socket_addr: SocketAddr, -} - /// Starts the API server. #[must_use] pub async fn start(app_data: Arc, net_ip: &str, net_port: u16, implementation: &Version) -> api::Running { diff --git a/src/web/api/server/custom_axum.rs b/src/web/api/server/custom_axum.rs new file mode 100644 index 00000000..5705ef24 --- /dev/null +++ b/src/web/api/server/custom_axum.rs @@ -0,0 +1,275 @@ +//! Wrapper for Axum server to add timeouts. +//! +//! Copyright (c) Eray Karatay ([@programatik29](https://github.com/programatik29)). +//! +//! See: . +//! +//! If a client opens a HTTP connection and it does not send any requests, the +//! connection is closed after a timeout. You can test it with: +//! +//! ```text +//! telnet 127.0.0.1 1212 +//! Trying 127.0.0.1... +//! Connected to 127.0.0.1. +//! Escape character is '^]'. +//! Connection closed by foreign host. +//! ``` +//! +//! If you want to know more about Axum and timeouts see . +use std::future::Ready; +use std::io::ErrorKind; +use std::net::TcpListener; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Duration; + +use axum_server::accept::Accept; +use axum_server::tls_rustls::{RustlsAcceptor, RustlsConfig}; +use axum_server::Server; +use futures_util::{ready, Future}; +use http_body::{Body, Frame}; +use hyper::Response; +use hyper_util::rt::TokioTimer; +use pin_project_lite::pin_project; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender}; +use tokio::time::{Instant, Sleep}; +use tower::Service; + +const HTTP1_HEADER_READ_TIMEOUT: Duration = Duration::from_secs(5); +const HTTP2_KEEP_ALIVE_TIMEOUT: Duration = Duration::from_secs(5); +const HTTP2_KEEP_ALIVE_INTERVAL: Duration = Duration::from_secs(5); + +#[must_use] +pub fn from_tcp_with_timeouts(socket: TcpListener) -> Server { + add_timeouts(axum_server::from_tcp(socket)) +} + +#[must_use] +pub fn from_tcp_rustls_with_timeouts(socket: TcpListener, tls: RustlsConfig) -> Server { + add_timeouts(axum_server::from_tcp_rustls(socket, tls)) +} + +fn add_timeouts(mut server: Server) -> Server { + server.http_builder().http1().timer(TokioTimer::new()); + server.http_builder().http2().timer(TokioTimer::new()); + + server.http_builder().http1().header_read_timeout(HTTP1_HEADER_READ_TIMEOUT); + server + .http_builder() + .http2() + .keep_alive_timeout(HTTP2_KEEP_ALIVE_TIMEOUT) + .keep_alive_interval(HTTP2_KEEP_ALIVE_INTERVAL); + + server +} + +#[derive(Clone)] +pub struct TimeoutAcceptor; + +impl Accept for TimeoutAcceptor { + type Stream = TimeoutStream; + type Service = TimeoutService; + type Future = Ready>; + + fn accept(&self, stream: I, service: S) -> Self::Future { + let (tx, rx) = mpsc::unbounded_channel(); + + let stream = TimeoutStream::new(stream, HTTP1_HEADER_READ_TIMEOUT, rx); + let service = TimeoutService::new(service, tx); + + std::future::ready(Ok((stream, service))) + } +} + +#[derive(Clone)] +pub struct TimeoutService { + inner: S, + sender: UnboundedSender, +} + +impl TimeoutService { + fn new(inner: S, sender: UnboundedSender) -> Self { + Self { inner, sender } + } +} + +impl Service for TimeoutService +where + S: Service>, +{ + type Response = Response>; + type Error = S::Error; + type Future = TimeoutServiceFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + // send timer wait signal + let _ = self.sender.send(TimerSignal::Wait); + + TimeoutServiceFuture::new(self.inner.call(req), self.sender.clone()) + } +} + +pin_project! { + pub struct TimeoutServiceFuture { + #[pin] + inner: F, + sender: Option>, + } +} + +impl TimeoutServiceFuture { + fn new(inner: F, sender: UnboundedSender) -> Self { + Self { + inner, + sender: Some(sender), + } + } +} + +impl Future for TimeoutServiceFuture +where + F: Future, E>>, +{ + type Output = Result>, E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + this.inner.poll(cx).map(|result| { + result.map(|response| { + response.map(|body| TimeoutBody::new(body, this.sender.take().expect("future polled after ready"))) + }) + }) + } +} + +enum TimerSignal { + Wait, + Reset, +} + +pin_project! { + pub struct TimeoutBody { + #[pin] + inner: B, + sender: UnboundedSender, + } +} + +impl TimeoutBody { + fn new(inner: B, sender: UnboundedSender) -> Self { + Self { inner, sender } + } +} + +impl Body for TimeoutBody { + type Data = B::Data; + type Error = B::Error; + + fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll, Self::Error>>> { + let this = self.project(); + let option = ready!(this.inner.poll_frame(cx)); + + if option.is_none() { + let _ = this.sender.send(TimerSignal::Reset); + } + + Poll::Ready(option) + } + + fn is_end_stream(&self) -> bool { + let is_end_stream = self.inner.is_end_stream(); + + if is_end_stream { + let _ = self.sender.send(TimerSignal::Reset); + } + + is_end_stream + } + + fn size_hint(&self) -> http_body::SizeHint { + self.inner.size_hint() + } +} + +pub struct TimeoutStream { + inner: IO, + // hyper requires unpin + sleep: Pin>, + duration: Duration, + waiting: bool, + receiver: UnboundedReceiver, + finished: bool, +} + +impl TimeoutStream { + fn new(inner: IO, duration: Duration, receiver: UnboundedReceiver) -> Self { + Self { + inner, + sleep: Box::pin(tokio::time::sleep(duration)), + duration, + waiting: false, + receiver, + finished: false, + } + } +} + +impl AsyncRead for TimeoutStream { + fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { + if !self.finished { + match Pin::new(&mut self.receiver).poll_recv(cx) { + // reset the timer + Poll::Ready(Some(TimerSignal::Reset)) => { + self.waiting = false; + + let deadline = Instant::now() + self.duration; + self.sleep.as_mut().reset(deadline); + } + // enter waiting mode (for response body last chunk) + Poll::Ready(Some(TimerSignal::Wait)) => self.waiting = true, + Poll::Ready(None) => self.finished = true, + Poll::Pending => (), + } + } + + if !self.waiting { + // return error if timer is elapsed + if let Poll::Ready(()) = self.sleep.as_mut().poll(cx) { + return Poll::Ready(Err(std::io::Error::new(ErrorKind::TimedOut, "request header read timed out"))); + } + } + + Pin::new(&mut self.inner).poll_read(cx, buf) + } +} + +impl AsyncWrite for TimeoutStream { + fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + Pin::new(&mut self.inner).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_shutdown(cx) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + Pin::new(&mut self.inner).poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } +} diff --git a/src/web/api/server/mod.rs b/src/web/api/server/mod.rs index e6943e87..a327361e 100644 --- a/src/web/api/server/mod.rs +++ b/src/web/api/server/mod.rs @@ -1,15 +1,20 @@ +pub mod custom_axum; +pub mod signals; pub mod v1; use std::net::SocketAddr; use std::sync::Arc; +use axum_server::Handle; use log::info; -use tokio::net::TcpListener; -use tokio::sync::oneshot::{self, Sender}; +use tokio::sync::oneshot::{Receiver, Sender}; use v1::routes::router; -use super::{Running, ServerStartedMessage}; +use self::signals::{Halted, Started}; +use super::Running; use crate::common::AppData; +use crate::web::api::server::custom_axum::TimeoutAcceptor; +use crate::web::api::server::signals::graceful_shutdown; /// Starts the API server. /// @@ -21,13 +26,14 @@ pub async fn start(app_data: Arc, net_ip: &str, net_port: u16) -> Runni .parse() .expect("API server socket address to be valid."); - let (tx, rx) = oneshot::channel::(); + let (tx_start, rx) = tokio::sync::oneshot::channel::(); + let (_tx_halt, rx_halt) = tokio::sync::oneshot::channel::(); // Run the API server let join_handle = tokio::spawn(async move { info!("Starting API server with net config: {} ...", config_socket_addr); - start_server(config_socket_addr, app_data.clone(), tx).await; + start_server(config_socket_addr, app_data.clone(), tx_start, rx_halt).await; info!("API server stopped"); @@ -46,27 +52,34 @@ pub async fn start(app_data: Arc, net_ip: &str, net_port: u16) -> Runni } } -async fn start_server(config_socket_addr: SocketAddr, app_data: Arc, tx: Sender) { - let tcp_listener = TcpListener::bind(config_socket_addr) - .await - .expect("tcp listener to bind to a socket address"); +async fn start_server( + config_socket_addr: SocketAddr, + app_data: Arc, + tx_start: Sender, + rx_halt: Receiver, +) { + let router = router(app_data); + let socket = std::net::TcpListener::bind(config_socket_addr).expect("Could not bind tcp_listener to address."); + let address = socket.local_addr().expect("Could not get local_addr from tcp_listener."); - let bound_addr = tcp_listener - .local_addr() - .expect("tcp listener to be bound to a socket address."); + let handle = Handle::new(); - info!("API server listening on http://{}", bound_addr); // # DevSkim: ignore DS137138 + tokio::task::spawn(graceful_shutdown( + handle.clone(), + rx_halt, + format!("Shutting down API server on socket address: {address}"), + )); - let app = router(app_data); + info!("API server listening on http://{}", address); // # DevSkim: ignore DS137138 - tx.send(ServerStartedMessage { socket_addr: bound_addr }) + tx_start + .send(Started { socket_addr: address }) .expect("the API server should not be dropped"); - axum::serve(tcp_listener, app.into_make_service_with_connect_info::()) - .with_graceful_shutdown(async move { - tokio::signal::ctrl_c().await.expect("Failed to listen to shutdown signal."); - info!("Stopping API server on http://{} ...", bound_addr); // # DevSkim: ignore DS137138 - }) + custom_axum::from_tcp_with_timeouts(socket) + .handle(handle) + .acceptor(TimeoutAcceptor) + .serve(router.into_make_service_with_connect_info::()) .await .expect("API server should be running"); } diff --git a/src/web/api/server/signals.rs b/src/web/api/server/signals.rs new file mode 100644 index 00000000..872d6094 --- /dev/null +++ b/src/web/api/server/signals.rs @@ -0,0 +1,88 @@ +use std::net::SocketAddr; +use std::time::Duration; + +use derive_more::Display; +use log::info; +use tokio::time::sleep; + +/// This is the message that the "launcher" spawned task sends to the main +/// application process to notify the service was successfully started. +#[derive(Copy, Clone, Debug, Display)] +pub struct Started { + pub socket_addr: SocketAddr, +} + +/// This is the message that the "launcher" spawned task receives from the main +/// application process to notify the service to shutdown. +#[derive(Copy, Clone, Debug, Display)] +pub enum Halted { + Normal, +} + +pub async fn graceful_shutdown(handle: axum_server::Handle, rx_halt: tokio::sync::oneshot::Receiver, message: String) { + shutdown_signal_with_message(rx_halt, message).await; + + info!("Sending graceful shutdown signal"); + handle.graceful_shutdown(Some(Duration::from_secs(90))); + + println!("!! shuting down in 90 seconds !!"); + + loop { + sleep(Duration::from_secs(1)).await; + + info!("remaining alive connections: {}", handle.connection_count()); + } +} + +/// Same as `shutdown_signal()`, but shows a message when it resolves. +pub async fn shutdown_signal_with_message(rx_halt: tokio::sync::oneshot::Receiver, message: String) { + shutdown_signal(rx_halt).await; + + info!("{message}"); +} + +/// Resolves when the `stop_receiver` or the `global_shutdown_signal()` resolves. +/// +/// # Panics +/// +/// Will panic if the `stop_receiver` resolves with an error. +pub async fn shutdown_signal(rx_halt: tokio::sync::oneshot::Receiver) { + let halt = async { + match rx_halt.await { + Ok(signal) => signal, + Err(err) => panic!("Failed to install stop signal: {err}"), + } + }; + + tokio::select! { + signal = halt => { info!("Halt signal processed: {}", signal) }, + () = global_shutdown_signal() => { info!("Global shutdown signal processed") } + } +} + +/// Resolves on `ctrl_c` or the `terminate` signal. +/// +/// # Panics +/// +/// Will panic if the `ctrl_c` or `terminate` signal resolves with an error. +pub async fn global_shutdown_signal() { + let ctrl_c = async { + tokio::signal::ctrl_c().await.expect("failed to install Ctrl+C handler"); + }; + + #[cfg(unix)] + let terminate = async { + tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) + .expect("failed to install signal handler") + .recv() + .await; + }; + + #[cfg(not(unix))] + let terminate = std::future::pending::<()>(); + + tokio::select! { + () = ctrl_c => {}, + () = terminate => {} + } +}