Skip to content

Commit

Permalink
feat: add a way to limit the number of subscriptions per connection (#…
Browse files Browse the repository at this point in the history
…739)

* feat: limit the number of subscriptions

Closing #729

* fix nit

* Update core/src/server/helpers.rs

* add integration tests + some fixes so it works

* cargo fmt

* fix doc links

* Unsubscribe calls should avoid subscription limits

Point to Tokio 1.16 (we use a method from it), and a little special treatment for unsubscribe methods

* No resource limiting for Unsubscribe calls

* Test that we can still unsubscribe after hitting a limit

* Fix a comment typo

Co-authored-by: Alexandru Vasile <[email protected]>

* Update core/src/server/rpc_module.rs

* Update core/src/server/rpc_module.rs

Co-authored-by: James Wilson <[email protected]>
Co-authored-by: Alexandru Vasile <[email protected]>
  • Loading branch information
3 people authored May 3, 2022
1 parent 661870a commit 816ecca
Show file tree
Hide file tree
Showing 15 changed files with 269 additions and 46 deletions.
2 changes: 1 addition & 1 deletion benches/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jsonrpc-http-server = { version = "18.0.0", optional = true }
jsonrpc-pubsub = { version = "18.0.0", optional = true }
num_cpus = "1"
serde_json = "1"
tokio = { version = "1.8", features = ["rt-multi-thread"] }
tokio = { version = "1.16", features = ["rt-multi-thread"] }

[[bench]]
name = "bench"
Expand Down
4 changes: 2 additions & 2 deletions client/http-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ jsonrpsee-core = { path = "../../core", version = "0.11.0", features = ["client"
serde = { version = "1.0", default-features = false, features = ["derive"] }
serde_json = "1.0"
thiserror = "1.0"
tokio = { version = "1.8", features = ["time"] }
tokio = { version = "1.16", features = ["time"] }
tracing = "0.1"

[dev-dependencies]
jsonrpsee-test-utils = { path = "../../test-utils" }
tokio = { version = "1.8", features = ["net", "rt-multi-thread", "macros"] }
tokio = { version = "1.16", features = ["net", "rt-multi-thread", "macros"] }

[features]
default = ["tls"]
Expand Down
4 changes: 2 additions & 2 deletions core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ rustc-hash = { version = "1", optional = true }
rand = { version = "0.8", optional = true }
soketto = { version = "0.7.1", optional = true }
parking_lot = { version = "0.12", optional = true }
tokio = { version = "1.8", optional = true }
tokio = { version = "1.16", optional = true }
wasm-bindgen-futures = { version = "0.4.19", optional = true }
futures-timer = { version = "3", optional = true }

Expand Down Expand Up @@ -66,5 +66,5 @@ async-wasm-client = [

[dev-dependencies]
serde_json = "1.0"
tokio = { version = "1.8", features = ["macros", "rt"] }
tokio = { version = "1.16", features = ["macros", "rt"] }
jsonrpsee = { path = "../jsonrpsee", features = ["server", "macros"] }
5 changes: 5 additions & 0 deletions core/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,11 @@ impl<Notif> Subscription<Notif> {
) -> Self {
Self { to_back, notifs_rx, kind, marker: PhantomData }
}

/// Return the subscription type and, if applicable, ID.
pub fn kind(&self) -> &SubscriptionKind {
&self.kind
}
}

/// Batch request message.
Expand Down
61 changes: 61 additions & 0 deletions core/src/server/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@
// DEALINGS IN THE SOFTWARE.

use std::io;
use std::sync::Arc;

use crate::{to_json_raw_value, Error};
use futures_channel::mpsc;
use futures_util::StreamExt;
use jsonrpsee_types::error::{ErrorCode, ErrorObject, ErrorResponse, OVERSIZED_RESPONSE_CODE, OVERSIZED_RESPONSE_MSG};
use jsonrpsee_types::{Id, InvalidRequest, Response};
use serde::Serialize;
use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore};

/// Bounded writer that allows writing at most `max_len` bytes.
///
Expand Down Expand Up @@ -196,8 +198,53 @@ pub async fn collect_batch_response(rx: mpsc::UnboundedReceiver<String>) -> Stri
buf
}

/// A permitted subscription.
#[derive(Debug)]
pub struct SubscriptionPermit {
_permit: OwnedSemaphorePermit,
resource: Arc<Notify>,
}

impl SubscriptionPermit {
/// Get the handle to [`tokio::sync::Notify`].
pub fn handle(&self) -> Arc<Notify> {
self.resource.clone()
}
}

/// Wrapper over [`tokio::sync::Notify`] with bounds check.
#[derive(Debug, Clone)]
pub struct BoundedSubscriptions {
resource: Arc<Notify>,
guard: Arc<Semaphore>,
}

impl BoundedSubscriptions {
/// Create a new bounded subscription.
pub fn new(max_subscriptions: u32) -> Self {
Self { resource: Arc::new(Notify::new()), guard: Arc::new(Semaphore::new(max_subscriptions as usize)) }
}

/// Attempts to acquire a subscription slot.
///
/// Fails if `max_subscriptions` have been exceeded.
pub fn acquire(&self) -> Option<SubscriptionPermit> {
Arc::clone(&self.guard)
.try_acquire_owned()
.ok()
.map(|p| SubscriptionPermit { _permit: p, resource: self.resource.clone() })
}

/// Close all subscriptions.
pub fn close(&self) {
self.resource.notify_waiters();
}
}

#[cfg(test)]
mod tests {
use crate::server::helpers::BoundedSubscriptions;

use super::{BoundedWriter, Id, Response};

#[test]
Expand All @@ -215,4 +262,18 @@ mod tests {
// NOTE: `"` is part of the serialization so 101 characters.
assert!(serde_json::to_writer(&mut writer, &"x".repeat(99)).is_err());
}

#[test]
fn bounded_subscriptions_work() {
let subs = BoundedSubscriptions::new(5);
let mut handles = Vec::new();

for _ in 0..5 {
handles.push(subs.acquire().unwrap());
}

assert!(subs.acquire().is_none());
handles.swap_remove(0);
assert!(subs.acquire().is_some());
}
}
47 changes: 30 additions & 17 deletions core/src/server/rpc_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use std::sync::Arc;

use crate::error::{Error, SubscriptionClosed};
use crate::id_providers::RandomIntegerIdProvider;
use crate::server::helpers::MethodSink;
use crate::server::helpers::{BoundedSubscriptions, MethodSink, SubscriptionPermit};
use crate::server::resource_limiting::{ResourceGuard, ResourceTable, ResourceVec, Resources};
use crate::traits::{IdProvider, ToRpcParams};
use futures_channel::mpsc;
Expand All @@ -48,7 +48,7 @@ use jsonrpsee_types::{
use parking_lot::Mutex;
use rustc_hash::FxHashMap;
use serde::{de::DeserializeOwned, Serialize};
use tokio::sync::{watch, Notify};
use tokio::sync::watch;

/// A `MethodCallback` is an RPC endpoint, callable with a standard JSON-RPC request,
/// implemented as a function pointer to a `Fn` function taking four arguments:
Expand All @@ -61,6 +61,8 @@ pub type AsyncMethod<'a> = Arc<
>;
/// Method callback for subscriptions.
pub type SubscriptionMethod = Arc<dyn Send + Sync + Fn(Id, Params, &MethodSink, ConnState) -> bool>;
// Method callback to unsubscribe.
type UnsubscriptionMethod = Arc<dyn Send + Sync + Fn(Id, Params, &MethodSink, ConnectionId) -> bool>;

/// Connection ID, used for stateful protocol such as WebSockets.
/// For stateless protocols such as http it's unused, so feel free to set it some hardcoded value.
Expand All @@ -70,15 +72,15 @@ pub type ConnectionId = usize;
/// A 3-tuple containing:
/// - Call result as a `String`,
/// - a [`mpsc::UnboundedReceiver<String>`] to receive future subscription results
/// - a [`tokio::sync::Notify`] to allow subscribers to notify their [`SubscriptionSink`] when they disconnect.
pub type RawRpcResponse = (String, mpsc::UnboundedReceiver<String>, Arc<Notify>);
/// - a [`crate::server::helpers::SubscriptionPermit`] to allow subscribers to notify their [`SubscriptionSink`] when they disconnect.
pub type RawRpcResponse = (String, mpsc::UnboundedReceiver<String>, SubscriptionPermit);

/// Helper struct to manage subscriptions.
pub struct ConnState<'a> {
/// Connection ID
pub conn_id: ConnectionId,
/// Get notified when the connection to subscribers is closed.
pub close_notify: Arc<Notify>,
pub close_notify: SubscriptionPermit,
/// ID provider.
pub id_provider: &'a dyn IdProvider,
}
Expand Down Expand Up @@ -114,8 +116,10 @@ pub enum MethodKind {
Sync(SyncMethod),
/// Asynchronous method handler.
Async(AsyncMethod<'static>),
/// Subscription method handler
/// Subscription method handler.
Subscription(SubscriptionMethod),
/// Unsubscription method handler.
Unsubscription(UnsubscriptionMethod),
}

/// Information about resources the method uses during its execution. Initialized when the the server starts.
Expand Down Expand Up @@ -189,6 +193,13 @@ impl MethodCallback {
}
}

fn new_unsubscription(callback: UnsubscriptionMethod) -> Self {
MethodCallback {
callback: MethodKind::Unsubscription(callback),
resources: MethodResources::Uninitialized([].into()),
}
}

/// Attempt to claim resources prior to executing a method. On success returns a guard that releases
/// claimed resources when dropped.
pub fn claim(&self, name: &str, resources: &Resources) -> Result<ResourceGuard, Error> {
Expand All @@ -210,6 +221,7 @@ impl Debug for MethodKind {
Self::Async(_) => write!(f, "Async"),
Self::Sync(_) => write!(f, "Sync"),
Self::Subscription(_) => write!(f, "Subscription"),
Self::Unsubscription(_) => write!(f, "Unsubscription"),
}
}
}
Expand Down Expand Up @@ -393,17 +405,19 @@ impl Methods {
let sink = MethodSink::new(tx_sink);
let id = req.id.clone();
let params = Params::new(req.params.map(|params| params.get()));
let notify = Arc::new(Notify::new());
let bounded_subs = BoundedSubscriptions::new(u32::MAX);
let close_notify = bounded_subs.acquire().expect("u32::MAX permits is sufficient; qed");
let notify = bounded_subs.acquire().expect("u32::MAX permits is sufficient; qed");

let _result = match self.method(&req.method).map(|c| &c.callback) {
None => sink.send_error(req.id, ErrorCode::MethodNotFound.into()),
Some(MethodKind::Sync(cb)) => (cb)(id, params, &sink),
Some(MethodKind::Async(cb)) => (cb)(id.into_owned(), params.into_owned(), sink, 0, None).await,
Some(MethodKind::Subscription(cb)) => {
let close_notify = notify.clone();
let conn_state = ConnState { conn_id: 0, close_notify, id_provider: &RandomIntegerIdProvider };
(cb)(id, params, &sink, conn_state)
}
Some(MethodKind::Unsubscription(cb)) => (cb)(id, params, &sink, 0),
};

let resp = rx_sink.next().await.expect("tx and rx still alive; qed");
Expand Down Expand Up @@ -707,7 +721,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
{
self.methods.mut_callbacks().insert(
unsubscribe_method_name,
MethodCallback::new_subscription(Arc::new(move |id, params, sink, conn| {
MethodCallback::new_unsubscription(Arc::new(move |id, params, sink, conn_id| {
let sub_id = match params.one::<RpcSubscriptionId>() {
Ok(sub_id) => sub_id,
Err(_) => {
Expand All @@ -722,8 +736,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
};
let sub_id = sub_id.into_owned();

let result =
subscribers.lock().remove(&SubscriptionKey { conn_id: conn.conn_id, sub_id }).is_some();
let result = subscribers.lock().remove(&SubscriptionKey { conn_id, sub_id }).is_some();

sink.send_response(id, result)
})),
Expand Down Expand Up @@ -757,7 +770,7 @@ struct InnerPendingSubscription {
/// Sink.
sink: MethodSink,
/// Get notified when subscribers leave so we can exit
close_notify: Option<Arc<Notify>>,
close_notify: Option<SubscriptionPermit>,
/// MethodCallback.
method: &'static str,
/// Unique subscription.
Expand Down Expand Up @@ -819,7 +832,7 @@ pub struct SubscriptionSink {
/// Sink.
inner: MethodSink,
/// Get notified when subscribers leave so we can exit
close_notify: Option<Arc<Notify>>,
close_notify: Option<SubscriptionPermit>,
/// MethodCallback.
method: &'static str,
/// Unique subscription.
Expand Down Expand Up @@ -892,8 +905,8 @@ impl SubscriptionSink {
T: Serialize,
E: std::fmt::Display,
{
let conn_closed = match self.close_notify.clone() {
Some(close_notify) => close_notify,
let conn_closed = match self.close_notify.as_ref().map(|cn| cn.handle()) {
Some(cn) => cn,
None => {
return SubscriptionClosed::RemotePeerAborted;
}
Expand Down Expand Up @@ -1035,7 +1048,7 @@ impl Drop for SubscriptionSink {
/// Wrapper struct that maintains a subscription "mainly" for testing.
#[derive(Debug)]
pub struct Subscription {
close_notify: Option<Arc<Notify>>,
close_notify: Option<SubscriptionPermit>,
rx: mpsc::UnboundedReceiver<String>,
sub_id: RpcSubscriptionId<'static>,
}
Expand All @@ -1045,7 +1058,7 @@ impl Subscription {
pub fn close(&mut self) {
tracing::trace!("[Subscription::close] Notifying");
if let Some(n) = self.close_notify.take() {
n.notify_one()
n.handle().notify_one()
}
}
/// Get the subscription ID
Expand Down
2 changes: 1 addition & 1 deletion examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ futures = "0.3"
jsonrpsee = { path = "../jsonrpsee", features = ["full"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3.3", features = ["env-filter"] }
tokio = { version = "1.8", features = ["full"] }
tokio = { version = "1.16", features = ["full"] }
tokio-stream = { version = "0.1", features = ["sync"] }
serde_json = { version = "1" }

Expand Down
2 changes: 1 addition & 1 deletion http-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ globset = "0.4"
lazy_static = "1.4"
tracing = "0.1"
serde_json = "1"
tokio = { version = "1.8", features = ["rt-multi-thread", "macros"] }
tokio = { version = "1.16", features = ["rt-multi-thread", "macros"] }
unicase = "2.6.0"

[dev-dependencies]
Expand Down
4 changes: 2 additions & 2 deletions http-server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ async fn process_validated_request(
false
}
},
MethodKind::Subscription(_) => {
MethodKind::Subscription(_) | MethodKind::Unsubscription(_) => {
tracing::error!("Subscriptions not supported on HTTP");
sink.send_error(req.id, ErrorCode::InternalError.into());
false
Expand Down Expand Up @@ -622,7 +622,7 @@ async fn process_validated_request(
None
}
},
MethodKind::Subscription(_) => {
MethodKind::Subscription(_) | MethodKind::Unsubscription(_) => {
tracing::error!("Subscriptions not supported on HTTP");
sink.send_error(req.id, ErrorCode::InternalError.into());
middleware.on_result(&req.method, false, request_start);
Expand Down
2 changes: 1 addition & 1 deletion proc-macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@ proc-macro-crate = "1"
[dev-dependencies]
jsonrpsee = { path = "../jsonrpsee", features = ["full"] }
trybuild = "1.0"
tokio = { version = "1.8", features = ["rt", "macros"] }
tokio = { version = "1.16", features = ["rt", "macros"] }
futures-channel = { version = "0.3.14", default-features = false }
futures-util = { version = "0.3.14", default-features = false }
2 changes: 1 addition & 1 deletion test-utils/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@ tracing = "0.1"
serde = { version = "1", default-features = false, features = ["derive"] }
serde_json = "1"
soketto = { version = "0.7.1", features = ["http"] }
tokio = { version = "1.8", features = ["net", "rt-multi-thread", "macros", "time"] }
tokio = { version = "1.16", features = ["net", "rt-multi-thread", "macros", "time"] }
tokio-util = { version = "0.7", features = ["compat"] }
2 changes: 1 addition & 1 deletion tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ env_logger = "0.9"
beef = { version = "0.5.1", features = ["impl_serde"] }
futures = { version = "0.3.14", default-features = false, features = ["std"] }
jsonrpsee = { path = "../jsonrpsee", features = ["full"] }
tokio = { version = "1.8", features = ["full"] }
tokio = { version = "1.16", features = ["full"] }
tracing = "0.1"
serde = "1"
serde_json = "1"
Expand Down
Loading

0 comments on commit 816ecca

Please sign in to comment.