From 6e989d3726f6dae5330415c7a47d3dbc2a1206c9 Mon Sep 17 00:00:00 2001 From: Rahul Subramaniyam <78006270+rahulksnv@users.noreply.github.com> Date: Tue, 27 Dec 2022 10:34:34 -0800 Subject: [PATCH] Resizable semaphore (#1019) * Add ResizableSemaphore * Clean ups * Add unit tests * Use the sem changes * Clean up * Fix race * Remove async lock * Restructure * Address comments * Bring back threshold for bumping quota * Address comments * Fix expect message Co-authored-by: Nazar Mokrynskyi --- crates/subspace-networking/src/create.rs | 121 ++------------ crates/subspace-networking/src/node.rs | 70 ++------ crates/subspace-networking/src/node_runner.rs | 49 ++++-- crates/subspace-networking/src/shared.rs | 23 ++- crates/subspace-networking/src/utils.rs | 151 ++++++++++++++++++ crates/subspace-networking/src/utils/tests.rs | 57 ++++++- 6 files changed, 292 insertions(+), 179 deletions(-) diff --git a/crates/subspace-networking/src/create.rs b/crates/subspace-networking/src/create.rs index ceb6ebe5f6..a5a0218672 100644 --- a/crates/subspace-networking/src/create.rs +++ b/crates/subspace-networking/src/create.rs @@ -1,6 +1,3 @@ -#[cfg(test)] -mod tests; - pub use crate::behavior::custom_record_store::ValueGetter; use crate::behavior::custom_record_store::{ CustomRecordStore, MemoryProviderStorage, NoRecordStorage, @@ -11,7 +8,7 @@ use crate::node::{CircuitRelayClientError, Node}; use crate::node_runner::{NodeRunner, NodeRunnerConfig}; use crate::request_responses::RequestHandler; use crate::shared::Shared; -use crate::utils::convert_multiaddresses; +use crate::utils::{convert_multiaddresses, ResizableSemaphore}; use crate::BootstrappedNetworkingParameters; use futures::channel::mpsc; use libp2p::core::muxing::StreamMuxerBox; @@ -32,13 +29,11 @@ use libp2p::websocket::WsConfig; use libp2p::yamux::YamuxConfig; use libp2p::{core, identity, noise, Multiaddr, PeerId, Transport, TransportError}; use std::num::NonZeroUsize; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::{Arc, Weak}; +use std::sync::Arc; use std::time::Duration; use std::{fmt, io}; use subspace_core_primitives::{crypto, PIECE_SIZE}; use thiserror::Error; -use tokio::sync::Semaphore; use tracing::{error, info}; const KADEMLIA_PROTOCOL: &[u8] = b"/subspace/kad/0.1.0"; @@ -70,11 +65,11 @@ const YAMUX_MAX_STREAMS: usize = 256; /// /// We restrict this so we don't exceed number of incoming streams for single peer, but this value /// will be boosted depending on number of connected peers. -const KADEMLIA_BASE_CONCURRENT_TASKS: usize = 30; +const KADEMLIA_BASE_CONCURRENT_TASKS: NonZeroUsize = NonZeroUsize::new(30).expect("Not zero; qed"); /// Above base limit will be boosted by specified number for every peer connected starting with /// second peer, such that it scaled with network connectivity, but the exact coefficient might need /// to be tweaked in the future. -const KADEMLIA_CONCURRENT_TASKS_BOOST_PER_PEER: usize = 1; +pub(crate) const KADEMLIA_CONCURRENT_TASKS_BOOST_PER_PEER: usize = 1; /// Base limit for number of any concurrent tasks except Kademlia. /// /// We configure total number of streams per connection to 256. Here we assume half of them might be @@ -82,70 +77,12 @@ const KADEMLIA_CONCURRENT_TASKS_BOOST_PER_PEER: usize = 1; /// /// We restrict this so we don't exceed number of streams for single peer, but this value will be /// boosted depending on number of connected peers. 
-const REGULAR_BASE_CONCURRENT_TASKS: usize = 120 - KADEMLIA_BASE_CONCURRENT_TASKS; +const REGULAR_BASE_CONCURRENT_TASKS: NonZeroUsize = + NonZeroUsize::new(120 - KADEMLIA_BASE_CONCURRENT_TASKS.get()).expect("Not zero; qed"); /// Above base limit will be boosted by specified number for every peer connected starting with /// second peer, such that it scaled with network connectivity, but the exact coefficient might need /// to be tweaked in the future. -const REGULAR_CONCURRENT_TASKS_BOOST_PER_PEER: usize = 2; -/// How many peers should node be connected to before boosting turns on. -/// -/// 1 means boosting starts with second peer. -const CONCURRENT_TASKS_BOOST_PEERS_THRESHOLD: NonZeroUsize = - NonZeroUsize::new(5).expect("Not zero; qed"); -const SEMAPHORE_MAINTENANCE_INTERVAL: Duration = Duration::from_secs(5); - -async fn maintain_semaphore_permits_capacity( - semaphore: &Semaphore, - interval: Duration, - connected_peers_count_weak: Weak, - boost_per_peer: usize, - boost_peers_threshold: NonZeroUsize, -) { - let base_permits = semaphore.available_permits(); - // Total permits technically supported by semaphore - let mut total_permits = base_permits; - // Some permits might be reserved due to number of peers decreasing and will be released back if - // necessary, this is because semaphore supports increasing number of - let mut reserved_permits = Vec::new(); - loop { - let connected_peers_count = match connected_peers_count_weak.upgrade() { - Some(connected_peers_count) => connected_peers_count.load(Ordering::Relaxed), - None => { - return; - } - }; - let expected_total_permits = base_permits - + connected_peers_count.saturating_sub(boost_peers_threshold.get()) * boost_per_peer; - - // Release reserves to match expected number of permits if necessary - while total_permits < expected_total_permits && !reserved_permits.is_empty() { - reserved_permits.pop(); - total_permits += 1; - } - // If reserved permits were not sufficient, add permits to the semaphore directly. 
- if total_permits < expected_total_permits { - semaphore.add_permits(expected_total_permits - total_permits); - total_permits = expected_total_permits; - } - // Peers disconnected and expected number of permits went down, we need to put some into - // reserve - if total_permits > expected_total_permits { - let to_reserve = total_permits - expected_total_permits; - reserved_permits.reserve(to_reserve); - for _ in 0..to_reserve { - reserved_permits.push( - semaphore - .acquire() - .await - .expect("We never close a semaphore; qed"), - ); - } - total_permits = expected_total_permits; - } - - tokio::time::sleep(interval).await; - } -} +pub(crate) const REGULAR_CONCURRENT_TASKS_BOOST_PER_PEER: usize = 2; /// Defines relay configuration for the Node #[derive(Clone, Debug)] @@ -368,42 +305,16 @@ where // Create final structs let (command_sender, command_receiver) = mpsc::channel(1); - let shared = Arc::new(Shared::new(local_peer_id, command_sender)); - let shared_weak = Arc::downgrade(&shared); - - let kademlia_tasks_semaphore = Arc::new(Semaphore::new(KADEMLIA_BASE_CONCURRENT_TASKS)); - let regular_tasks_semaphore = Arc::new(Semaphore::new(REGULAR_BASE_CONCURRENT_TASKS)); - - tokio::spawn({ - let kademlia_tasks_semaphore = Arc::clone(&kademlia_tasks_semaphore); - let connected_peers_count_weak = Arc::downgrade(&shared.connected_peers_count); - - async move { - maintain_semaphore_permits_capacity( - &kademlia_tasks_semaphore, - SEMAPHORE_MAINTENANCE_INTERVAL, - connected_peers_count_weak, - KADEMLIA_CONCURRENT_TASKS_BOOST_PER_PEER, - CONCURRENT_TASKS_BOOST_PEERS_THRESHOLD, - ) - .await; - } - }); - tokio::spawn({ - let regular_tasks_semaphore = Arc::clone(®ular_tasks_semaphore); - let connected_peers_count_weak = Arc::downgrade(&shared.connected_peers_count); + let kademlia_tasks_semaphore = ResizableSemaphore::new(KADEMLIA_BASE_CONCURRENT_TASKS); + let regular_tasks_semaphore = ResizableSemaphore::new(REGULAR_BASE_CONCURRENT_TASKS); - async move { - maintain_semaphore_permits_capacity( - ®ular_tasks_semaphore, - SEMAPHORE_MAINTENANCE_INTERVAL, - connected_peers_count_weak, - REGULAR_CONCURRENT_TASKS_BOOST_PER_PEER, - CONCURRENT_TASKS_BOOST_PEERS_THRESHOLD, - ) - .await; - } - }); + let shared = Arc::new(Shared::new( + local_peer_id, + command_sender, + kademlia_tasks_semaphore.clone(), + regular_tasks_semaphore.clone(), + )); + let shared_weak = Arc::downgrade(&shared); let node = Node::new(shared, kademlia_tasks_semaphore, regular_tasks_semaphore); let node_runner = NodeRunner::::new(NodeRunnerConfig:: { diff --git a/crates/subspace-networking/src/node.rs b/crates/subspace-networking/src/node.rs index 25933e2a8d..071c8bb628 100644 --- a/crates/subspace-networking/src/node.rs +++ b/crates/subspace-networking/src/node.rs @@ -1,6 +1,7 @@ use crate::request_handlers::generic_request_handler::GenericRequest; use crate::request_responses; use crate::shared::{Command, CreatedSubscription, Shared}; +use crate::utils::{ResizableSemaphore, ResizableSemaphorePermit}; use bytes::Bytes; use event_listener_primitives::HandlerId; use futures::channel::mpsc::SendError; @@ -16,7 +17,6 @@ use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; use thiserror::Error; -use tokio::sync::{OwnedSemaphorePermit, Semaphore}; use tokio::time::sleep; use tracing::{error, trace}; @@ -29,7 +29,7 @@ pub struct TopicSubscription { command_sender: Option>, #[pin] receiver: mpsc::UnboundedReceiver, - _permit: OwnedSemaphorePermit, + _permit: ResizableSemaphorePermit, } impl Stream for TopicSubscription { 
@@ -257,15 +257,15 @@ impl From for CircuitRelayClientError { #[must_use = "Node doesn't do anything if dropped"] pub struct Node { shared: Arc, - kademlia_tasks_semaphore: Arc, - regular_tasks_semaphore: Arc, + kademlia_tasks_semaphore: ResizableSemaphore, + regular_tasks_semaphore: ResizableSemaphore, } impl Node { pub(crate) fn new( shared: Arc, - kademlia_tasks_semaphore: Arc, - regular_tasks_semaphore: Arc, + kademlia_tasks_semaphore: ResizableSemaphore, + regular_tasks_semaphore: ResizableSemaphore, ) -> Self { Self { shared, @@ -283,12 +283,7 @@ impl Node { &self, key: Multihash, ) -> Result>, GetValueError> { - let permit = self - .kademlia_tasks_semaphore - .clone() - .acquire_owned() - .await - .expect("We never close a semaphore; qed"); + let permit = self.kademlia_tasks_semaphore.acquire().await; let (result_sender, result_receiver) = mpsc::unbounded(); self.shared @@ -310,12 +305,7 @@ impl Node { key: Multihash, value: Vec, ) -> Result, PutValueError> { - let permit = self - .kademlia_tasks_semaphore - .clone() - .acquire_owned() - .await - .expect("We never close a semaphore; qed"); + let permit = self.kademlia_tasks_semaphore.acquire().await; let (result_sender, result_receiver) = mpsc::unbounded(); self.shared @@ -334,12 +324,7 @@ impl Node { } pub async fn subscribe(&self, topic: Sha256Topic) -> Result { - let permit = self - .regular_tasks_semaphore - .clone() - .acquire_owned() - .await - .expect("We never close a semaphore; qed"); + let permit = self.regular_tasks_semaphore.acquire().await; let (result_sender, result_receiver) = oneshot::channel(); self.shared @@ -366,12 +351,7 @@ impl Node { } pub async fn publish(&self, topic: Sha256Topic, message: Vec) -> Result<(), PublishError> { - let _permit = self - .regular_tasks_semaphore - .clone() - .acquire_owned() - .await - .expect("We never close a semaphore; qed"); + let _permit = self.regular_tasks_semaphore.acquire().await; let (result_sender, result_receiver) = oneshot::channel(); self.shared @@ -396,14 +376,7 @@ impl Node { where Request: GenericRequest, { - // TODO: Cancelling this method's future will drop the permit, but will not abort the - // request if it is already initiated - let _permit = self - .regular_tasks_semaphore - .clone() - .acquire_owned() - .await - .expect("We never close a semaphore; qed"); + let _permit = self.regular_tasks_semaphore.acquire().await; let (result_sender, result_receiver) = oneshot::channel(); let command = Command::GenericRequest { peer_id, @@ -424,12 +397,7 @@ impl Node { &self, key: Multihash, ) -> Result, GetClosestPeersError> { - let permit = self - .kademlia_tasks_semaphore - .clone() - .acquire_owned() - .await - .expect("We never close a semaphore; qed"); + let permit = self.kademlia_tasks_semaphore.acquire().await; trace!(?key, "Starting 'GetClosestPeers' request."); let (result_sender, result_receiver) = mpsc::unbounded(); @@ -482,12 +450,7 @@ impl Node { &self, key: Multihash, ) -> Result, AnnounceError> { - let permit = self - .kademlia_tasks_semaphore - .clone() - .acquire_owned() - .await - .expect("We never close a semaphore; qed"); + let permit = self.kademlia_tasks_semaphore.acquire().await; let (result_sender, result_receiver) = mpsc::unbounded(); trace!(?key, "Starting 'start_announcing' request."); @@ -529,12 +492,7 @@ impl Node { &self, key: Multihash, ) -> Result, GetProvidersError> { - let permit = self - .kademlia_tasks_semaphore - .clone() - .acquire_owned() - .await - .expect("We never close a semaphore; qed"); + let permit = 
self.kademlia_tasks_semaphore.acquire().await; let (result_sender, result_receiver) = mpsc::unbounded(); trace!(?key, "Starting 'get_providers' request."); diff --git a/crates/subspace-networking/src/node_runner.rs b/crates/subspace-networking/src/node_runner.rs index 19009c7177..e5fc6d76f3 100644 --- a/crates/subspace-networking/src/node_runner.rs +++ b/crates/subspace-networking/src/node_runner.rs @@ -1,9 +1,12 @@ use crate::behavior::custom_record_store::CustomRecordStore; use crate::behavior::persistent_parameters::NetworkingParametersRegistry; use crate::behavior::{Behavior, Event}; +use crate::create::{ + KADEMLIA_CONCURRENT_TASKS_BOOST_PER_PEER, REGULAR_CONCURRENT_TASKS_BOOST_PER_PEER, +}; use crate::request_responses::{Event as RequestResponseEvent, IfDisconnected}; use crate::shared::{Command, CreatedSubscription, Shared}; -use crate::utils; +use crate::utils::{is_global_address_or_dns, ResizableSemaphorePermit}; use bytes::Bytes; use futures::channel::mpsc; use futures::future::Fuse; @@ -24,39 +27,45 @@ use nohash_hasher::IntMap; use std::collections::hash_map::Entry; use std::collections::{HashMap, HashSet}; use std::fmt::Debug; +use std::num::NonZeroUsize; use std::pin::Pin; use std::sync::atomic::Ordering; use std::sync::Weak; use std::time::Duration; -use tokio::sync::OwnedSemaphorePermit; use tokio::time::Sleep; use tracing::{debug, error, trace, warn}; +/// How many peers should node be connected to before boosting turns on. +/// +/// 1 means boosting starts with second peer. +const CONCURRENT_TASKS_BOOST_PEERS_THRESHOLD: NonZeroUsize = + NonZeroUsize::new(5).expect("Not zero; qed"); + enum QueryResultSender { Value { sender: mpsc::UnboundedSender>, // Just holding onto permit while data structure is not dropped - _permit: OwnedSemaphorePermit, + _permit: ResizableSemaphorePermit, }, ClosestPeers { sender: mpsc::UnboundedSender, // Just holding onto permit while data structure is not dropped - _permit: OwnedSemaphorePermit, + _permit: ResizableSemaphorePermit, }, Providers { sender: mpsc::UnboundedSender, // Just holding onto permit while data structure is not dropped - _permit: OwnedSemaphorePermit, + _permit: ResizableSemaphorePermit, }, Announce { sender: mpsc::UnboundedSender<()>, // Just holding onto permit while data structure is not dropped - _permit: OwnedSemaphorePermit, + _permit: ResizableSemaphorePermit, }, PutValue { sender: mpsc::UnboundedSender<()>, // Just holding onto permit while data structure is not dropped - _permit: OwnedSemaphorePermit, + _permit: ResizableSemaphorePermit, }, } @@ -315,7 +324,17 @@ where let is_reserved_peer = self.reserved_peers.contains_key(&peer_id); debug!(%peer_id, %is_reserved_peer, "Connection established [{num_established} from peer]"); - shared.connected_peers_count.fetch_add(1, Ordering::SeqCst); + if shared.connected_peers_count.fetch_add(1, Ordering::SeqCst) + >= CONCURRENT_TASKS_BOOST_PEERS_THRESHOLD.get() + { + // The peer count exceeded the threshold, bump up the quota. 
+ shared + .kademlia_tasks_semaphore + .expand(KADEMLIA_CONCURRENT_TASKS_BOOST_PER_PEER); + shared + .regular_tasks_semaphore + .expand(REGULAR_CONCURRENT_TASKS_BOOST_PER_PEER); + } let (in_connections_number, out_connections_number) = { let network_info = self.swarm.network_info(); @@ -375,7 +394,17 @@ where }; debug!("Connection closed with peer {peer_id} [{num_established} from peer]"); - shared.connected_peers_count.fetch_sub(1, Ordering::SeqCst); + if shared.connected_peers_count.fetch_sub(1, Ordering::SeqCst) + > CONCURRENT_TASKS_BOOST_PEERS_THRESHOLD.get() + { + // The previous peer count was over the threshold, reclaim the quota. + shared + .kademlia_tasks_semaphore + .shrink(KADEMLIA_CONCURRENT_TASKS_BOOST_PER_PEER); + shared + .regular_tasks_semaphore + .shrink(REGULAR_CONCURRENT_TASKS_BOOST_PER_PEER); + } } SwarmEvent::OutgoingConnectionError { peer_id, error } => match error { DialError::Transport(ref addresses) => { @@ -431,7 +460,7 @@ where if kademlia_enabled { for address in info.listen_addrs { if !self.allow_non_global_addresses_in_dht - && !utils::is_global_address_or_dns(&address) + && !is_global_address_or_dns(&address) { trace!( %local_peer_id, diff --git a/crates/subspace-networking/src/shared.rs b/crates/subspace-networking/src/shared.rs index 001a290b1b..7a9be91a0b 100644 --- a/crates/subspace-networking/src/shared.rs +++ b/crates/subspace-networking/src/shared.rs @@ -2,6 +2,7 @@ //! queries, subscriptions, various events and shared information. use crate::request_responses::RequestFailure; +use crate::utils::{ResizableSemaphore, ResizableSemaphorePermit}; use bytes::Bytes; use event_listener_primitives::Bag; use futures::channel::{mpsc, oneshot}; @@ -12,7 +13,6 @@ use libp2p::{Multiaddr, PeerId}; use parking_lot::Mutex; use std::sync::atomic::AtomicUsize; use std::sync::Arc; -use tokio::sync::OwnedSemaphorePermit; #[derive(Debug)] pub(crate) struct CreatedSubscription { @@ -27,13 +27,13 @@ pub(crate) enum Command { GetValue { key: Multihash, result_sender: mpsc::UnboundedSender>, - permit: OwnedSemaphorePermit, + permit: ResizableSemaphorePermit, }, PutValue { key: Multihash, value: Vec, result_sender: mpsc::UnboundedSender<()>, - permit: OwnedSemaphorePermit, + permit: ResizableSemaphorePermit, }, Subscribe { topic: Sha256Topic, @@ -51,7 +51,7 @@ pub(crate) enum Command { GetClosestPeers { key: Multihash, result_sender: mpsc::UnboundedSender, - permit: OwnedSemaphorePermit, + permit: ResizableSemaphorePermit, }, GenericRequest { peer_id: PeerId, @@ -65,7 +65,7 @@ pub(crate) enum Command { StartAnnouncing { key: Multihash, result_sender: mpsc::UnboundedSender<()>, - permit: OwnedSemaphorePermit, + permit: ResizableSemaphorePermit, }, StopAnnouncing { key: Multihash, @@ -74,7 +74,7 @@ pub(crate) enum Command { GetProviders { key: Multihash, result_sender: mpsc::UnboundedSender, - permit: OwnedSemaphorePermit, + permit: ResizableSemaphorePermit, }, } @@ -93,16 +93,25 @@ pub(crate) struct Shared { pub(crate) connected_peers_count: Arc, /// Sender end of the channel for sending commands to the swarm. 
pub(crate) command_sender: mpsc::Sender, + pub(crate) kademlia_tasks_semaphore: ResizableSemaphore, + pub(crate) regular_tasks_semaphore: ResizableSemaphore, } impl Shared { - pub(crate) fn new(id: PeerId, command_sender: mpsc::Sender) -> Self { + pub(crate) fn new( + id: PeerId, + command_sender: mpsc::Sender, + kademlia_tasks_semaphore: ResizableSemaphore, + regular_tasks_semaphore: ResizableSemaphore, + ) -> Self { Self { handlers: Handlers::default(), id, listeners: Mutex::default(), connected_peers_count: Arc::new(AtomicUsize::new(0)), command_sender, + kademlia_tasks_semaphore, + regular_tasks_semaphore, } } } diff --git a/crates/subspace-networking/src/utils.rs b/crates/subspace-networking/src/utils.rs index 5b829aefb4..a62d7e183c 100644 --- a/crates/subspace-networking/src/utils.rs +++ b/crates/subspace-networking/src/utils.rs @@ -5,8 +5,11 @@ mod tests; use libp2p::multiaddr::Protocol; use libp2p::{Multiaddr, PeerId}; +use parking_lot::Mutex; use std::marker::PhantomData; use std::num::NonZeroUsize; +use std::sync::Arc; +use tokio::sync::Notify; use tracing::warn; /// This test is successful only for global IP addresses and DNS names. @@ -95,3 +98,151 @@ pub(crate) fn convert_multiaddresses(addresses: Vec) -> Vec, + + /// To signal waiters for permits to be available + notify: Notify, +} + +/// The semaphore state. +#[derive(Debug)] +struct SemState { + /// The current capacity + capacity: usize, + + /// The current outstanding permits + usage: usize, +} + +impl SemState { + // Allocates a permit if available. + // Returns true if allocated, false otherwise. + fn alloc_one(&mut self) -> bool { + if self.usage < self.capacity { + self.usage += 1; + true + } else { + false + } + } + + // Returns a free permit to the free pool. + // Returns true if any waiters need to be notified. + fn free_one(&mut self) -> bool { + let prev_is_full = self.is_full(); + if let Some(dec) = self.usage.checked_sub(1) { + self.usage = dec; + } else { + panic!("SemState::free_one(): invalid free, state = {:?}", self); + } + + // Notify if we did a full -> available transition. + prev_is_full && !self.is_full() + } + + // Expands the max capacity by delta. + // Returns true if any waiters need to be notified. + fn expand(&mut self, delta: usize) -> bool { + let prev_is_full = self.is_full(); + self.capacity += delta; + + // Notify if we did a full -> available transition. + prev_is_full && !self.is_full() + } + + // Shrinks the max capacity by delta. + fn shrink(&mut self, delta: usize) { + if let Some(dec) = self.capacity.checked_sub(delta) { + self.capacity = dec; + } else { + panic!("SemState::shrink(): invalid shrink, state = {:?}", self); + } + } + + // Returns true if current usage exceeds capacity + fn is_full(&self) -> bool { + self.usage >= self.capacity + } +} + +/// Semaphore like implementation that allows both shrinking and expanding +/// the max permits. +#[derive(Clone, Debug)] +pub(crate) struct ResizableSemaphore(Arc); + +impl ResizableSemaphore { + pub(crate) fn new(capacity: NonZeroUsize) -> Self { + let shared = SemShared { + state: Mutex::new(SemState { + capacity: capacity.get(), + usage: 0, + }), + notify: Notify::new(), + }; + Self(Arc::new(shared)) + } + + // Acquires a permit. Waits until a permit is available. + pub(crate) async fn acquire(&self) -> ResizableSemaphorePermit { + loop { + let wait = { + let mut state = self.0.state.lock(); + if state.alloc_one() { + None + } else { + // This needs to be done under the lock to avoid race. 
+ Some(self.0.notify.notified()) + } + }; + + match wait { + Some(notified) => notified.await, + None => break, + } + } + ResizableSemaphorePermit(self.0.clone()) + } + + // Acquires a permit, doesn't wait for permits to be available. + // Currently used only for tests. + #[cfg(test)] + pub(crate) fn try_acquire(&self) -> Option { + let mut state = self.0.state.lock(); + if state.alloc_one() { + Some(ResizableSemaphorePermit(self.0.clone())) + } else { + None + } + } + + // Expands the capacity by the specified amount. + pub(crate) fn expand(&self, delta: usize) { + let notify_waiters = self.0.state.lock().expand(delta); + if notify_waiters { + self.0.notify.notify_waiters(); + } + } + + // Shrinks the capacity by the specified amount. + pub(crate) fn shrink(&self, delta: usize) { + self.0.state.lock().shrink(delta) + } +} + +/// The semaphore permit. +#[derive(Clone, Debug)] +pub(crate) struct ResizableSemaphorePermit(Arc); + +impl Drop for ResizableSemaphorePermit { + fn drop(&mut self) { + let notify_waiters = self.0.state.lock().free_one(); + if notify_waiters { + self.0.notify.notify_waiters(); + } + } +} diff --git a/crates/subspace-networking/src/utils/tests.rs b/crates/subspace-networking/src/utils/tests.rs index b65544c0e6..65e610ed05 100644 --- a/crates/subspace-networking/src/utils/tests.rs +++ b/crates/subspace-networking/src/utils/tests.rs @@ -1,4 +1,4 @@ -use super::CollectionBatcher; +use super::{CollectionBatcher, ResizableSemaphore}; use std::num::NonZeroUsize; #[test] @@ -60,3 +60,58 @@ fn test_batching() { assert_eq!(batcher.next_batch(collection.clone()), vec![3, 4, 5, 6]); assert_eq!(batcher.next_batch(collection), vec![7, 1, 2, 3]); } + +#[test] +fn test_resizable_semaphore_alloc() { + // Capacity = 3. We should be able to alloc only three permits. + let sem = ResizableSemaphore::new(NonZeroUsize::new(3).unwrap()); + let _permit_1 = sem.try_acquire().unwrap(); + let _permit_2 = sem.try_acquire().unwrap(); + let _permit_3 = sem.try_acquire().unwrap(); + assert!(sem.try_acquire().is_none()); +} + +#[test] +fn test_resizable_semaphore_expand() { + // Initial capacity = 3. + let sem = ResizableSemaphore::new(NonZeroUsize::new(3).unwrap()); + let _permit_1 = sem.try_acquire().unwrap(); + let _permit_2 = sem.try_acquire().unwrap(); + let _permit_3 = sem.try_acquire().unwrap(); + assert!(sem.try_acquire().is_none()); + + // Increase capacity of semaphore by 2, we should be able to alloc two more permits. + sem.expand(2); + let _permit_4 = sem.try_acquire().unwrap(); + let _permit_5 = sem.try_acquire().unwrap(); + assert!(sem.try_acquire().is_none()); +} + +#[test] +fn test_resizable_semaphore_shrink() { + // Initial capacity = 4, alloc 4 outstanding permits. + let sem = ResizableSemaphore::new(NonZeroUsize::new(4).unwrap()); + let permit_1 = sem.try_acquire().unwrap(); + let permit_2 = sem.try_acquire().unwrap(); + let permit_3 = sem.try_acquire().unwrap(); + let _permit_4 = sem.try_acquire().unwrap(); + assert!(sem.try_acquire().is_none()); + + // Shrink the capacity by 2, new capacity = 2. + sem.shrink(2); + + // Alloc should fail as outstanding permits(4) >= capacity(2). + assert!(sem.try_acquire().is_none()); + + // Free a permit, alloc should fail as outstanding permits(3) >= capacity(2). + std::mem::drop(permit_2); + assert!(sem.try_acquire().is_none()); + + // Free another permit, alloc should fail as outstanding permits(2) >= capacity(2). 
+ std::mem::drop(permit_3); + assert!(sem.try_acquire().is_none()); + + // Free another permit, alloc should succeed as outstanding permits(1) < capacity(2). + std::mem::drop(permit_1); + assert!(sem.try_acquire().is_some()); +}
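
Notes on how the pieces above compose (illustration only, not part of the patch): with the constants in create.rs and the connection handlers in node_runner.rs, the Kademlia semaphore starts at 30 permits and gains one extra permit for every connected peer beyond the fifth, while the regular semaphore starts at 90 (120 - 30) and gains two per peer beyond the fifth; the boost is reclaimed as peers disconnect. The sketch below exercises that flow against the ResizableSemaphore API introduced here, written as if appended to utils/tests.rs (it reuses that file's existing imports). The local BOOST_* constants and the peer-count loop are illustrative stand-ins for the real crate constants and swarm events.

// Illustrative only: mirrors CONCURRENT_TASKS_BOOST_PEERS_THRESHOLD and
// KADEMLIA_CONCURRENT_TASKS_BOOST_PER_PEER; reuses the imports already at the
// top of utils/tests.rs (ResizableSemaphore, NonZeroUsize).
const BOOST_PEERS_THRESHOLD: usize = 5;
const BOOST_PER_PEER: usize = 1;

#[test]
fn test_resizable_semaphore_tracks_peer_count() {
    // Base capacity of 3 keeps the test small; the real Kademlia semaphore starts at 30.
    let sem = ResizableSemaphore::new(NonZeroUsize::new(3).unwrap());
    let mut connected_peers = 0_usize;

    // Peers connect one by one, as in the SwarmEvent::ConnectionEstablished handler.
    // fetch_add returns the previous count, hence the `>=` comparison.
    for _ in 0..7 {
        if connected_peers >= BOOST_PEERS_THRESHOLD {
            sem.expand(BOOST_PER_PEER);
        }
        connected_peers += 1;
    }

    // Two of the seven peers are above the threshold: capacity is now 3 + 2 = 5.
    let permits: Vec<_> = (0..5).map(|_| sem.try_acquire().unwrap()).collect();
    assert!(sem.try_acquire().is_none());

    // Peers disconnect again, as in the SwarmEvent::ConnectionClosed handler.
    // fetch_sub returns the previous count, hence the `>` comparison.
    for _ in 0..7 {
        if connected_peers > BOOST_PEERS_THRESHOLD {
            sem.shrink(BOOST_PER_PEER);
        }
        connected_peers -= 1;
    }

    // Capacity is back to 3 while 5 permits are still outstanding, so allocation
    // only succeeds again once the outstanding permits are dropped.
    assert!(sem.try_acquire().is_none());
    std::mem::drop(permits);
    assert!(sem.try_acquire().is_some());
}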
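
One path the unit tests above don't exercise is the async acquire(): a caller blocked on a full semaphore should be woken when expand() raises the capacity, which is what the Notify in SemShared exists for. A minimal sketch of such a test follows, assuming the crate's dev-dependencies enable tokio's test macro; the test name and the yield_now()-based sequencing are illustrative.

#[tokio::test]
async fn test_resizable_semaphore_expand_wakes_waiter() {
    // Capacity 1, and the single permit is held for the whole test.
    let sem = ResizableSemaphore::new(NonZeroUsize::new(1).unwrap());
    let _held = sem.acquire().await;

    // This task blocks in acquire() because the semaphore is full.
    let waiter = tokio::spawn({
        let sem = sem.clone();
        async move { sem.acquire().await }
    });

    // Let the waiter register with the Notify, then raise the capacity.
    tokio::task::yield_now().await;
    sem.expand(1);

    // expand() performs a full -> available transition, so the waiter is
    // notified and obtains the new permit.
    let _second = waiter.await.expect("Waiter task doesn't panic; qed");
}

This is also where the resizable design pays off compared to the removed maintain_semaphore_permits_capacity task: a tokio Semaphore can only grow at runtime, so the old code had to park acquired permits in a reserve to emulate shrinking, whereas shrink() simply lowers the capacity and lets outstanding permits drain on their own.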