-
Notifications
You must be signed in to change notification settings - Fork 55.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
rust: pinned-init: add
CMutex
example
Add the `CMutex` example from the pinned-init repository. This type is used in the examples instead of the kernel mutex, since that is not available outside of the `kernel` crate. Change the doctests to use `CMutex` instead of the kernel mutex. Signed-off-by: Benno Lossin <[email protected]>
- Loading branch information
Showing
4 changed files
with
468 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
#![feature(allocator_api)] | ||
|
||
use core::convert::Infallible; | ||
|
||
#[derive(Debug)] | ||
pub struct AllocError; | ||
|
||
impl From<Infallible> for AllocError { | ||
fn from(_: Infallible) -> Self { | ||
Self | ||
} | ||
} | ||
|
||
impl From<core::alloc::AllocError> for AllocError { | ||
fn from(_: core::alloc::AllocError) -> Self { | ||
Self | ||
} | ||
} | ||
|
||
fn main() {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
#![feature(allocator_api)] | ||
use core::{ | ||
cell::Cell, | ||
convert::Infallible, | ||
marker::PhantomPinned, | ||
pin::Pin, | ||
ptr::{self, NonNull}, | ||
}; | ||
|
||
use pinned_init::*; | ||
mod error; | ||
use error::AllocError; | ||
|
||
#[pin_data(PinnedDrop)] | ||
#[repr(C)] | ||
#[derive(Debug)] | ||
pub struct ListHead { | ||
next: Link, | ||
prev: Link, | ||
#[pin] | ||
pin: PhantomPinned, | ||
} | ||
|
||
impl ListHead { | ||
#[inline] | ||
pub fn new() -> impl PinInit<Self, Infallible> { | ||
try_pin_init!(&this in Self { | ||
next: unsafe { Link::new_unchecked(this) }, | ||
prev: unsafe { Link::new_unchecked(this) }, | ||
pin: PhantomPinned, | ||
}? Infallible) | ||
} | ||
|
||
#[inline] | ||
pub fn insert_next(list: &ListHead) -> impl PinInit<Self, Infallible> + '_ { | ||
try_pin_init!(&this in Self { | ||
prev: list.next.prev().replace(unsafe { Link::new_unchecked(this)}), | ||
next: list.next.replace(unsafe { Link::new_unchecked(this)}), | ||
pin: PhantomPinned, | ||
}? Infallible) | ||
} | ||
|
||
#[inline] | ||
pub fn insert_prev(list: &ListHead) -> impl PinInit<Self, Infallible> + '_ { | ||
try_pin_init!(&this in Self { | ||
next: list.prev.next().replace(unsafe { Link::new_unchecked(this)}), | ||
prev: list.prev.replace(unsafe { Link::new_unchecked(this)}), | ||
pin: PhantomPinned, | ||
}? Infallible) | ||
} | ||
|
||
#[inline] | ||
pub fn next(&self) -> Option<NonNull<Self>> { | ||
if ptr::eq(self.next.as_ptr(), self) { | ||
None | ||
} else { | ||
Some(unsafe { NonNull::new_unchecked(self.next.as_ptr() as *mut Self) }) | ||
} | ||
} | ||
|
||
#[allow(dead_code)] | ||
pub fn size(&self) -> usize { | ||
let mut size = 1; | ||
let mut cur = self.next.clone(); | ||
while !ptr::eq(self, cur.cur()) { | ||
cur = cur.next().clone(); | ||
size += 1; | ||
} | ||
size | ||
} | ||
} | ||
|
||
#[pinned_drop] | ||
impl PinnedDrop for ListHead { | ||
//#[inline] | ||
fn drop(self: Pin<&mut Self>) { | ||
if !ptr::eq(self.next.as_ptr(), &*self) { | ||
let next = unsafe { &*self.next.as_ptr() }; | ||
let prev = unsafe { &*self.prev.as_ptr() }; | ||
next.prev.set(&self.prev); | ||
prev.next.set(&self.next); | ||
} | ||
} | ||
} | ||
|
||
#[repr(transparent)] | ||
#[derive(Clone, Debug)] | ||
struct Link(Cell<NonNull<ListHead>>); | ||
|
||
impl Link { | ||
#[inline] | ||
unsafe fn new_unchecked(ptr: NonNull<ListHead>) -> Self { | ||
Self(Cell::new(ptr)) | ||
} | ||
|
||
#[inline] | ||
fn next(&self) -> &Link { | ||
unsafe { &(*self.0.get().as_ptr()).next } | ||
} | ||
|
||
#[inline] | ||
fn prev(&self) -> &Link { | ||
unsafe { &(*self.0.get().as_ptr()).prev } | ||
} | ||
|
||
#[allow(dead_code)] | ||
fn cur(&self) -> &ListHead { | ||
unsafe { &*self.0.get().as_ptr() } | ||
} | ||
|
||
#[inline] | ||
fn replace(&self, other: Link) -> Link { | ||
unsafe { Link::new_unchecked(self.0.replace(other.0.get())) } | ||
} | ||
|
||
#[inline] | ||
fn as_ptr(&self) -> *const ListHead { | ||
self.0.get().as_ptr() | ||
} | ||
|
||
#[inline] | ||
fn set(&self, val: &Link) { | ||
self.0.set(val.0.get()); | ||
} | ||
} | ||
|
||
#[allow(dead_code)] | ||
#[cfg_attr(test, test)] | ||
fn main() -> Result<(), AllocError> { | ||
let a = Box::pin_init(ListHead::new())?; | ||
stack_pin_init!(let b = ListHead::insert_next(&a)); | ||
stack_pin_init!(let c = ListHead::insert_next(&a)); | ||
stack_pin_init!(let d = ListHead::insert_next(&b)); | ||
let e = Box::pin_init(ListHead::insert_next(&b))?; | ||
println!("a ({a:p}): {a:?}"); | ||
println!("b ({b:p}): {b:?}"); | ||
println!("c ({c:p}): {c:?}"); | ||
println!("d ({d:p}): {d:?}"); | ||
println!("e ({e:p}): {e:?}"); | ||
let mut inspect = &*a; | ||
while let Some(next) = inspect.next() { | ||
println!("({inspect:p}): {inspect:?}"); | ||
inspect = unsafe { &*next.as_ptr() }; | ||
if core::ptr::eq(inspect, &*a) { | ||
break; | ||
} | ||
} | ||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,195 @@ | ||
#![feature(allocator_api)] | ||
use core::{ | ||
cell::{Cell, UnsafeCell}, | ||
marker::PhantomPinned, | ||
ops::{Deref, DerefMut}, | ||
pin::Pin, | ||
sync::atomic::{AtomicBool, Ordering}, | ||
}; | ||
use std::{ | ||
sync::Arc, | ||
thread::{self, park, sleep, Builder, Thread}, | ||
time::Duration, | ||
}; | ||
|
||
use pinned_init::*; | ||
#[allow(unused_attributes)] | ||
#[path = "./linked_list.rs"] | ||
pub mod linked_list; | ||
use linked_list::*; | ||
|
||
pub struct SpinLock { | ||
inner: AtomicBool, | ||
} | ||
|
||
impl SpinLock { | ||
#[inline] | ||
pub fn acquire(&self) -> SpinLockGuard<'_> { | ||
while self | ||
.inner | ||
.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed) | ||
.is_err() | ||
{ | ||
while self.inner.load(Ordering::Relaxed) { | ||
thread::yield_now(); | ||
} | ||
} | ||
SpinLockGuard(self) | ||
} | ||
|
||
#[inline] | ||
pub const fn new() -> Self { | ||
Self { | ||
inner: AtomicBool::new(false), | ||
} | ||
} | ||
} | ||
|
||
pub struct SpinLockGuard<'a>(&'a SpinLock); | ||
|
||
impl Drop for SpinLockGuard<'_> { | ||
#[inline] | ||
fn drop(&mut self) { | ||
self.0.inner.store(false, Ordering::Release); | ||
} | ||
} | ||
|
||
#[pin_data] | ||
pub struct CMutex<T> { | ||
#[pin] | ||
wait_list: ListHead, | ||
spin_lock: SpinLock, | ||
locked: Cell<bool>, | ||
#[pin] | ||
data: UnsafeCell<T>, | ||
} | ||
|
||
impl<T> CMutex<T> { | ||
#[inline] | ||
pub fn new(val: impl PinInit<T>) -> impl PinInit<Self> { | ||
pin_init!(CMutex { | ||
wait_list <- ListHead::new(), | ||
spin_lock: SpinLock::new(), | ||
locked: Cell::new(false), | ||
data <- unsafe { | ||
pin_init_from_closure(|slot: *mut UnsafeCell<T>| val.__pinned_init(slot.cast::<T>())) | ||
}, | ||
}) | ||
} | ||
|
||
#[inline] | ||
pub fn lock(&self) -> Pin<CMutexGuard<'_, T>> { | ||
let mut sguard = self.spin_lock.acquire(); | ||
if self.locked.get() { | ||
stack_pin_init!(let wait_entry = WaitEntry::insert_new(&self.wait_list)); | ||
// println!("wait list length: {}", self.wait_list.size()); | ||
while self.locked.get() { | ||
drop(sguard); | ||
park(); | ||
sguard = self.spin_lock.acquire(); | ||
} | ||
drop(wait_entry); | ||
} | ||
self.locked.set(true); | ||
unsafe { | ||
Pin::new_unchecked(CMutexGuard { | ||
mtx: self, | ||
_pin: PhantomPinned, | ||
}) | ||
} | ||
} | ||
|
||
pub fn get_data_mut(self: Pin<&mut Self>) -> &mut T { | ||
// SAFETY: we have an exclusive reference and thus nobody has access to data. | ||
unsafe { &mut *self.data.get() } | ||
} | ||
} | ||
|
||
unsafe impl<T: Send> Send for CMutex<T> {} | ||
unsafe impl<T: Send> Sync for CMutex<T> {} | ||
|
||
pub struct CMutexGuard<'a, T> { | ||
mtx: &'a CMutex<T>, | ||
_pin: PhantomPinned, | ||
} | ||
|
||
impl<'a, T> Drop for CMutexGuard<'a, T> { | ||
#[inline] | ||
fn drop(&mut self) { | ||
let sguard = self.mtx.spin_lock.acquire(); | ||
self.mtx.locked.set(false); | ||
if let Some(list_field) = self.mtx.wait_list.next() { | ||
let wait_entry = list_field.as_ptr().cast::<WaitEntry>(); | ||
unsafe { (*wait_entry).thread.unpark() }; | ||
} | ||
drop(sguard); | ||
} | ||
} | ||
|
||
impl<'a, T> Deref for CMutexGuard<'a, T> { | ||
type Target = T; | ||
|
||
#[inline] | ||
fn deref(&self) -> &Self::Target { | ||
unsafe { &*self.mtx.data.get() } | ||
} | ||
} | ||
|
||
impl<'a, T> DerefMut for CMutexGuard<'a, T> { | ||
#[inline] | ||
fn deref_mut(&mut self) -> &mut Self::Target { | ||
unsafe { &mut *self.mtx.data.get() } | ||
} | ||
} | ||
|
||
#[pin_data] | ||
#[repr(C)] | ||
struct WaitEntry { | ||
#[pin] | ||
wait_list: ListHead, | ||
thread: Thread, | ||
} | ||
|
||
impl WaitEntry { | ||
#[inline] | ||
fn insert_new(list: &ListHead) -> impl PinInit<Self> + '_ { | ||
pin_init!(Self { | ||
thread: thread::current(), | ||
wait_list <- ListHead::insert_prev(list), | ||
}) | ||
} | ||
} | ||
|
||
#[allow(dead_code)] | ||
#[cfg_attr(test, test)] | ||
fn main() { | ||
let mtx: Pin<Arc<CMutex<usize>>> = | ||
<Arc<_> as InPlaceInit<_>>::pin_init(CMutex::new(0)).unwrap(); | ||
let mut handles = vec![]; | ||
let thread_count = 20; | ||
let workload = if cfg!(miri) { 100 } else { 1_000_000 }; | ||
for i in 0..thread_count { | ||
let mtx = mtx.clone(); | ||
handles.push( | ||
Builder::new() | ||
.name(format!("worker #{i}")) | ||
.spawn(move || { | ||
for _ in 0..workload { | ||
*mtx.lock() += 1; | ||
} | ||
println!("{i} halfway"); | ||
sleep(Duration::from_millis((i as u64) * 10)); | ||
for _ in 0..workload { | ||
*mtx.lock() += 1; | ||
} | ||
println!("{i} finished"); | ||
}) | ||
.expect("should not fail"), | ||
); | ||
} | ||
for h in handles { | ||
h.join().expect("thread paniced"); | ||
} | ||
println!("{:?}", &*mtx.lock()); | ||
assert_eq!(*mtx.lock(), workload * thread_count * 2); | ||
} |
Oops, something went wrong.