From 1a1ab58ae43a7253c4cd5955b99fcc3568a022bc Mon Sep 17 00:00:00 2001 From: Peter Gerber Date: Tue, 2 Jul 2024 19:26:05 +0000 Subject: [PATCH] Add is_locked_serially() to check if we are in a #[serial] context --- serial_test/src/code_lock.rs | 124 +++++++++++++++++++++++++++++++++++ serial_test/src/lib.rs | 2 + serial_test/src/rwlock.rs | 4 ++ 3 files changed, 130 insertions(+) diff --git a/serial_test/src/code_lock.rs b/serial_test/src/code_lock.rs index cbb2161..168823e 100644 --- a/serial_test/src/code_lock.rs +++ b/serial_test/src/code_lock.rs @@ -34,6 +34,10 @@ impl UniqueReentrantMutex { pub fn is_locked(&self) -> bool { self.locks.is_locked() } + + pub fn is_locked_by_current_thread(&self) -> bool { + self.locks.is_locked_by_current_thread() + } } #[inline] @@ -44,6 +48,63 @@ pub(crate) fn global_locks() -> &'static HashMap { LOCKS.get_or_init(HashMap::new) } +/// Check if the current thread is holding a serial lock +/// +/// Can be used to assert that a piece of code can only be called +/// from a test marked `#[serial]`. +/// +/// Example, with `#[serial]`: +/// +/// ``` +/// use serial_test::{is_locked_serially, serial}; +/// +/// fn do_something_in_need_of_serialization() { +/// assert!(is_locked_serially(None)); +/// +/// // ... +/// } +/// +/// #[test] +/// # fn unused() {} +/// #[serial] +/// fn main() { +/// do_something_in_need_of_serialization(); +/// } +/// ``` +/// +/// Example, missing `#[serial]`: +/// +/// ```should_panic +/// use serial_test::{is_locked_serially, serial}; +/// +/// #[test] +/// # fn unused() {} +/// // #[serial] // <-- missing +/// fn main() { +/// assert!(is_locked_serially(None)); +/// } +/// ``` +/// +/// Example, `#[test(some_key)]`: +/// +/// ``` +/// use serial_test::{is_locked_serially, serial}; +/// +/// #[test] +/// # fn unused() {} +/// #[serial(some_key)] +/// fn main() { +/// assert!(is_locked_serially(Some("some_key"))); +/// assert!(!is_locked_serially(None)); +/// } +/// ``` +pub fn is_locked_serially(name: Option<&str>) -> bool { + global_locks() + .get(name.unwrap_or_default()) + .map(|lock| lock.get().is_locked_by_current_thread()) + .unwrap_or_default() +} + static MUTEX_ID: AtomicU32 = AtomicU32::new(1); impl UniqueReentrantMutex { @@ -68,3 +129,66 @@ pub(crate) fn check_new_key(name: &str) { Entry::Vacant(v) => v.insert_entry(UniqueReentrantMutex::new_mutex(name)), }; } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{local_parallel_core, local_serial_core}; + + const NAME1: &str = "NAME1"; + const NAME2: &str = "NAME2"; + + #[test] + fn assert_serially_locked_without_name() { + local_serial_core(vec![""], None, || { + assert!(is_locked_serially(None)); + assert!(!is_locked_serially(Some("no_such_name"))); + }); + } + + #[test] + fn assert_serially_locked_with_multiple_names() { + local_serial_core(vec![NAME1, NAME2], None, || { + assert!(is_locked_serially(Some(NAME1))); + assert!(is_locked_serially(Some(NAME2))); + assert!(!is_locked_serially(Some("no_such_name"))); + assert!(!is_locked_serially(None)); + }); + } + + #[test] + fn assert_serially_locked_when_actually_locked_parallel() { + local_parallel_core(vec![NAME1, NAME2], None, || { + assert!(!is_locked_serially(Some(NAME1))); + assert!(!is_locked_serially(Some(NAME2))); + assert!(!is_locked_serially(Some("no_such_name"))); + assert!(!is_locked_serially(None)); + }); + } + + #[test] + fn assert_serially_locked_outside_serial_lock() { + assert!(!is_locked_serially(Some(NAME1))); + assert!(!is_locked_serially(Some(NAME2))); + assert!(!is_locked_serially(None)); + + local_serial_core(vec![NAME1], None, || { + // ... + }); + + assert!(!is_locked_serially(Some(NAME1))); + assert!(!is_locked_serially(Some(NAME2))); + assert!(!is_locked_serially(None)); + } + + #[test] + fn assert_serially_locked_in_different_thread() { + local_serial_core(vec![NAME1, NAME2], None, || { + std::thread::spawn(|| { + assert!(!is_locked_serially(Some(NAME2))); + }) + .join() + .unwrap(); + }); + } +} diff --git a/serial_test/src/lib.rs b/serial_test/src/lib.rs index d6519f0..a5d71c0 100644 --- a/serial_test/src/lib.rs +++ b/serial_test/src/lib.rs @@ -112,3 +112,5 @@ pub use serial_test_derive::{parallel, serial}; #[cfg(feature = "file_locks")] pub use serial_test_derive::{file_parallel, file_serial}; + +pub use code_lock::is_locked_serially; diff --git a/serial_test/src/rwlock.rs b/serial_test/src/rwlock.rs index 0be0c90..fe8f4b6 100644 --- a/serial_test/src/rwlock.rs +++ b/serial_test/src/rwlock.rs @@ -54,6 +54,10 @@ impl Locks { self.arc.serial.is_locked() } + pub fn is_locked_by_current_thread(&self) -> bool { + self.arc.serial.is_owned_by_current_thread() + } + pub fn serial(&self) -> MutexGuardWrapper { #[cfg(feature = "logging")] debug!("Get serial lock '{}'", self.name);