diff --git a/examples/examples/jsonrpsee_as_service.rs b/examples/examples/jsonrpsee_as_service.rs index c738c05259..a87da4f730 100644 --- a/examples/examples/jsonrpsee_as_service.rs +++ b/examples/examples/jsonrpsee_as_service.rs @@ -36,6 +36,7 @@ use std::net::SocketAddr; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; +use futures::FutureExt; use hyper::header::AUTHORIZATION; use hyper::server::conn::AddrStream; use hyper::HeaderMap; @@ -45,16 +46,18 @@ use jsonrpsee::proc_macros::rpc; use jsonrpsee::server::middleware::rpc::{ResponseFuture, RpcServiceBuilder, RpcServiceT}; use jsonrpsee::server::{stop_channel, ServerHandle, StopHandle, TowerServiceBuilder}; use jsonrpsee::types::{ErrorObject, ErrorObjectOwned, Request}; -use jsonrpsee::ws_client::HeaderValue; +use jsonrpsee::ws_client::{HeaderValue, WsClientBuilder}; use jsonrpsee::{MethodResponse, Methods}; use tower::Service; use tower_http::cors::CorsLayer; use tracing_subscriber::util::SubscriberInitExt; -#[derive(Default, Clone)] +#[derive(Default, Clone, Debug)] struct Metrics { - ws_connections: Arc, - http_connections: Arc, + opened_ws_connections: Arc, + closed_ws_connections: Arc, + http_calls: Arc, + success_http_calls: Arc, } #[derive(Clone)] @@ -106,7 +109,9 @@ async fn main() -> anyhow::Result<()> { let filter = tracing_subscriber::EnvFilter::try_from_default_env()?; tracing_subscriber::FmtSubscriber::builder().with_env_filter(filter).finish().try_init()?; - let handle = run_server(); + let metrics = Metrics::default(); + + let handle = run_server(metrics.clone()); tokio::spawn(handle.stopped()); { @@ -117,6 +122,14 @@ async fn main() -> anyhow::Result<()> { tracing::info!("response: {x}"); } + { + let client = WsClientBuilder::default().build("ws://127.0.0.1:9944").await.unwrap(); + + // Fails because the authorization header is missing. + let x = client.trusted_call().await.unwrap_err(); + tracing::info!("response: {x}"); + } + { let mut headers = HeaderMap::new(); headers.insert(AUTHORIZATION, HeaderValue::from_static("don't care in this example")); @@ -127,10 +140,12 @@ async fn main() -> anyhow::Result<()> { tracing::info!("response: {x}"); } + tracing::info!("{:?}", metrics); + Ok(()) } -fn run_server() -> ServerHandle { +fn run_server(metrics: Metrics) -> ServerHandle { use hyper::service::{make_service_fn, service_fn}; let addr = SocketAddr::from(([127, 0, 0, 1], 9944)); @@ -159,7 +174,7 @@ fn run_server() -> ServerHandle { let per_conn = PerConnection { methods: ().into_rpc().into(), stop_handle: stop_handle.clone(), - metrics: Metrics::default(), + metrics, svc_builder: jsonrpsee::server::Server::builder() .set_http_middleware(tower::ServiceBuilder::new().layer(CorsLayer::permissive())) .max_connections(33) @@ -185,15 +200,40 @@ fn run_server() -> ServerHandle { let mut svc = svc_builder.set_rpc_middleware(rpc_middleware).build(methods, stop_handle); - async move { - // You can't determine whether the websocket upgrade handshake failed or not here. - let rp = svc.call(req).await; - if is_websocket { - metrics.ws_connections.fetch_add(1, Ordering::Relaxed); - } else { - metrics.http_connections.fetch_add(1, Ordering::Relaxed); + if is_websocket { + // Utilize the session close future to know when the actual WebSocket + // session was closed. + let session_close = svc.on_session_closed(); + + // A little bit weird API but the response to HTTP request must be returned below + // and we spawn a task to register when the session is closed. + tokio::spawn(async move { + session_close.await; + tracing::info!("Closed WebSocket connection"); + metrics.closed_ws_connections.fetch_add(1, Ordering::Relaxed); + }); + + async move { + tracing::info!("Opened WebSocket connection"); + metrics.opened_ws_connections.fetch_add(1, Ordering::Relaxed); + svc.call(req).await + } + .boxed() + } else { + // HTTP. + async move { + tracing::info!("Opened HTTP connection"); + metrics.http_calls.fetch_add(1, Ordering::Relaxed); + let rp = svc.call(req).await; + + if rp.is_ok() { + metrics.success_http_calls.fetch_add(1, Ordering::Relaxed); + } + + tracing::info!("Closed HTTP connection"); + rp } - rp + .boxed() } })) } diff --git a/server/Cargo.toml b/server/Cargo.toml index 266497c840..5820c80584 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -23,7 +23,7 @@ serde_json = { version = "1", features = ["raw_value"] } soketto = { version = "0.7.1", features = ["http"] } tokio = { version = "1.16", features = ["net", "rt-multi-thread", "macros", "time"] } tokio-util = { version = "0.7", features = ["compat"] } -tokio-stream = "0.1.7" +tokio-stream = { version = "0.1.7", features = ["sync"] } hyper = { version = "0.14", features = ["server", "http1", "http2"] } tower = "0.4.13" thiserror = "1" diff --git a/server/src/future.rs b/server/src/future.rs index d187ef58e2..d26b9cb936 100644 --- a/server/src/future.rs +++ b/server/src/future.rs @@ -30,10 +30,11 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use futures_util::{Stream, StreamExt}; +use futures_util::{Future, Stream, StreamExt}; use pin_project::pin_project; use tokio::sync::{watch, OwnedSemaphorePermit, Semaphore, TryAcquireError}; use tokio::time::Interval; +use tokio_stream::wrappers::BroadcastStream; /// Create channel to determine whether /// the server shall continue to run or not. @@ -157,3 +158,40 @@ impl Stream for IntervalStream { } } } + +#[derive(Debug, Clone)] +pub(crate) struct SessionClose(tokio::sync::broadcast::Sender<()>); + +impl SessionClose { + pub(crate) fn close(self) { + let _ = self.0.send(()); + } + + pub(crate) fn closed(&self) -> SessionClosedFuture { + SessionClosedFuture(BroadcastStream::new(self.0.subscribe())) + } +} + +/// A future that resolves when the connection has been closed. +#[derive(Debug)] +pub struct SessionClosedFuture(BroadcastStream<()>); + +impl Future for SessionClosedFuture { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.0.poll_next_unpin(cx) { + Poll::Pending => Poll::Pending, + // Only message is only sent and + // ignore can't keep up errors. + Poll::Ready(_) => Poll::Ready(()), + } + } +} + +pub(crate) fn session_close() -> (SessionClose, SessionClosedFuture) { + // SessionClosedFuture is closed after one message has been recevied + // and max one message is handled then it's closed. + let (tx, rx) = tokio::sync::broadcast::channel(1); + (SessionClose(tx), SessionClosedFuture(BroadcastStream::new(rx))) +} diff --git a/server/src/server.rs b/server/src/server.rs index 0f1b4d76d6..922269b0f1 100644 --- a/server/src/server.rs +++ b/server/src/server.rs @@ -33,7 +33,7 @@ use std::sync::Arc; use std::task::Poll; use std::time::Duration; -use crate::future::{ConnectionGuard, ServerHandle, StopHandle}; +use crate::future::{session_close, ConnectionGuard, ServerHandle, SessionClose, SessionClosedFuture, StopHandle}; use crate::middleware::rpc::{RpcService, RpcServiceBuilder, RpcServiceCfg, RpcServiceT}; use crate::transport::ws::BackgroundTaskParams; use crate::transport::{http, ws}; @@ -501,6 +501,7 @@ impl TowerServiceBuilder { http_middleware: tower::ServiceBuilder, } +impl TowerService { + /// A future that returns when the connection has been closed. + /// + /// This method must be called before every [`TowerService::call`] + /// because the `SessionClosedFuture` may already been consumed or + /// not used. + pub fn on_session_closed(&mut self) -> SessionClosedFuture { + if let Some(n) = self.rpc_middleware.on_session_close.as_mut() { + // If it's called more then once another listener is created. + n.closed() + } else { + let (session_close, fut) = session_close(); + self.rpc_middleware.on_session_close = Some(session_close); + fut + } + } +} + impl hyper::service::Service> for TowerService where @@ -979,6 +998,7 @@ where pub struct TowerServiceNoHttp { inner: ServiceData, rpc_middleware: RpcServiceBuilder, + on_session_close: Option, } impl hyper::service::Service> for TowerServiceNoHttp @@ -1004,6 +1024,7 @@ where let conn_guard = &self.inner.conn_guard; let stop_handle = self.inner.stop_handle.clone(); let conn_id = self.inner.conn_id; + let on_session_close = self.on_session_close.take(); tracing::trace!(target: LOG_TARGET, "{:?}", request); @@ -1076,6 +1097,7 @@ where sink, rx, pending_calls_completed, + on_session_close, }; ws::background_task(params).await; @@ -1176,6 +1198,7 @@ fn process_connection<'a, RpcMiddleware, HttpMiddleware, U>( conn_guard: conn_guard.clone(), }, rpc_middleware, + on_session_close: None, }; let service = http_middleware.service(tower_service); diff --git a/server/src/tests/helpers.rs b/server/src/tests/helpers.rs index 413aa7bd0d..968f22d81b 100644 --- a/server/src/tests/helpers.rs +++ b/server/src/tests/helpers.rs @@ -1,11 +1,18 @@ -use std::fmt; +use std::error::Error as StdError; use std::net::SocketAddr; +use std::sync::atomic::Ordering; +use std::sync::Arc; +use std::{fmt, sync::atomic::AtomicUsize}; -use crate::{RpcModule, ServerBuilder, ServerHandle}; +use crate::{stop_channel, RpcModule, Server, ServerBuilder, ServerHandle}; +use futures_util::FutureExt; +use hyper::server::conn::AddrStream; +use jsonrpsee_core::server::Methods; use jsonrpsee_core::{DeserializeOwned, RpcResult, StringError}; use jsonrpsee_test_utils::TimeoutFutureExt; use jsonrpsee_types::{error::ErrorCode, ErrorObject, ErrorObjectOwned, Response, ResponseSuccess}; +use tower::Service; use tracing_subscriber::{EnvFilter, FmtSubscriber}; pub(crate) struct TestContext; @@ -194,3 +201,63 @@ impl From for ErrorObjectOwned { fn invalid_params() -> ErrorObjectOwned { ErrorCode::InvalidParams.into() } + +#[derive(Debug, Clone, Default)] +pub(crate) struct Metrics { + pub(crate) ws_sessions_opened: Arc, + pub(crate) ws_sessions_closed: Arc, +} + +pub(crate) fn ws_server_with_stats(metrics: Metrics) -> SocketAddr { + use hyper::service::{make_service_fn, service_fn}; + + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let (stop_handle, server_handle) = stop_channel(); + let stop_handle2 = stop_handle.clone(); + + // And a MakeService to handle each connection... + let make_service = make_service_fn(move |_conn: &AddrStream| { + let stop_handle = stop_handle2.clone(); + let metrics = metrics.clone(); + + async move { + Ok::<_, Box>(service_fn(move |req| { + let is_websocket = crate::ws::is_upgrade_request(&req); + let metrics = metrics.clone(); + let stop_handle = stop_handle.clone(); + + let mut svc = + Server::builder().max_connections(33).to_service_builder().build(Methods::new(), stop_handle); + + if is_websocket { + // This should work for each callback. + let session_close1 = svc.on_session_closed(); + let session_close2 = svc.on_session_closed(); + + tokio::spawn(async move { + metrics.ws_sessions_opened.fetch_add(1, Ordering::SeqCst); + tokio::join!(session_close2, session_close1); + metrics.ws_sessions_closed.fetch_add(1, Ordering::SeqCst); + }); + + async move { svc.call(req).await }.boxed() + } else { + // HTTP. + async move { svc.call(req).await }.boxed() + } + })) + } + }); + + let server = hyper::Server::bind(&addr).serve(make_service); + + let addr = server.local_addr(); + + tokio::spawn(async move { + let graceful = server.with_graceful_shutdown(async move { stop_handle.shutdown().await }); + graceful.await.unwrap(); + drop(server_handle) + }); + + addr +} diff --git a/server/src/tests/ws.rs b/server/src/tests/ws.rs index 4829f41a3a..f7c1381d63 100644 --- a/server/src/tests/ws.rs +++ b/server/src/tests/ws.rs @@ -24,9 +24,10 @@ // IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. +use std::sync::atomic::Ordering; use std::time::Duration; -use crate::tests::helpers::{deser_call, init_logger, server_with_context}; +use crate::tests::helpers::{deser_call, init_logger, server_with_context, ws_server_with_stats, Metrics}; use crate::types::SubscriptionId; use crate::{BatchRequestConfig, RegisterMethodError}; use crate::{RpcModule, ServerBuilder}; @@ -874,6 +875,30 @@ async fn drop_client_with_pending_calls_works() { assert!(handle.stopped().with_timeout(MAX_TIMEOUT).await.is_ok()); } +#[tokio::test] +async fn server_notify_on_conn_close() { + init_logger(); + + let metrics = Metrics::default(); + let addr = ws_server_with_stats(metrics.clone()); + + let mut client = WebSocketTestClient::new(addr).with_default_timeout().await.unwrap().unwrap(); + + // Wait for the server to process + tokio::time::sleep(Duration::from_millis(100)).await; + + assert_eq!(metrics.ws_sessions_opened.load(Ordering::SeqCst), 1); + assert_eq!(metrics.ws_sessions_closed.load(Ordering::SeqCst), 0); + + client.close().with_default_timeout().await.unwrap().unwrap(); + + // Wait for the server to process + tokio::time::sleep(Duration::from_millis(100)).await; + + assert_eq!(metrics.ws_sessions_opened.load(Ordering::SeqCst), 1); + assert_eq!(metrics.ws_sessions_closed.load(Ordering::SeqCst), 1); +} + async fn server_with_infinite_call( timeout: Duration, tx: tokio::sync::mpsc::UnboundedSender<()>, diff --git a/server/src/transport/ws.rs b/server/src/transport/ws.rs index 11f3bef36c..4d843a51c3 100644 --- a/server/src/transport/ws.rs +++ b/server/src/transport/ws.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use std::time::Instant; -use crate::future::IntervalStream; +use crate::future::{IntervalStream, SessionClose}; use crate::middleware::rpc::{RpcService, RpcServiceBuilder, RpcServiceCfg, RpcServiceT}; use crate::server::{handle_rpc_call, ConnectionState, ServerConfig}; use crate::{PingConfig, LOG_TARGET}; @@ -55,6 +55,7 @@ pub(crate) struct BackgroundTaskParams { pub(crate) sink: MethodSink, pub(crate) rx: mpsc::Receiver, pub(crate) pending_calls_completed: mpsc::Receiver<()>, + pub(crate) on_session_close: Option, } pub(crate) async fn background_task(params: BackgroundTaskParams) @@ -70,6 +71,7 @@ where sink, rx, pending_calls_completed, + mut on_session_close, } = params; let ServerConfig { ping_config, batch_requests_config, max_request_body_size, max_response_body_size, .. } = server_cfg; @@ -180,6 +182,10 @@ where graceful_shutdown(result, pending_calls_completed, ws_stream, conn_tx, send_task_handle).await; drop(conn); + + if let Some(c) = on_session_close.take() { + c.close(); + } } /// A task that waits for new messages via the `rx channel` and sends them out on the `WebSocket`. @@ -456,6 +462,7 @@ where sink, rx, pending_calls_completed, + on_session_close: None, }; background_task(params).await;