Skip to content

Commit

Permalink
rust: pinned-init: add CMutex example
Browse files Browse the repository at this point in the history
Add the `CMutex` example from the pinned-init repository. This type is
used in the examples instead of the kernel mutex, since that is not
available outside of the `kernel` crate.
Change the doctests to use `CMutex` instead of the kernel mutex.

Signed-off-by: Benno Lossin <[email protected]>
  • Loading branch information
y86-dev committed Mar 14, 2024
1 parent 7289166 commit 2900f80
Show file tree
Hide file tree
Showing 4 changed files with 468 additions and 39 deletions.
20 changes: 20 additions & 0 deletions rust/pinned_init/examples/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#![feature(allocator_api)]

use core::convert::Infallible;

#[derive(Debug)]
pub struct AllocError;

impl From<Infallible> for AllocError {
fn from(_: Infallible) -> Self {
Self
}
}

impl From<core::alloc::AllocError> for AllocError {
fn from(_: core::alloc::AllocError) -> Self {
Self
}
}

fn main() {}
149 changes: 149 additions & 0 deletions rust/pinned_init/examples/linked_list.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
#![feature(allocator_api)]
use core::{
cell::Cell,
convert::Infallible,
marker::PhantomPinned,
pin::Pin,
ptr::{self, NonNull},
};

use pinned_init::*;
mod error;
use error::AllocError;

#[pin_data(PinnedDrop)]
#[repr(C)]
#[derive(Debug)]
pub struct ListHead {
next: Link,
prev: Link,
#[pin]
pin: PhantomPinned,
}

impl ListHead {
#[inline]
pub fn new() -> impl PinInit<Self, Infallible> {
try_pin_init!(&this in Self {
next: unsafe { Link::new_unchecked(this) },
prev: unsafe { Link::new_unchecked(this) },
pin: PhantomPinned,
}? Infallible)
}

#[inline]
pub fn insert_next(list: &ListHead) -> impl PinInit<Self, Infallible> + '_ {
try_pin_init!(&this in Self {
prev: list.next.prev().replace(unsafe { Link::new_unchecked(this)}),
next: list.next.replace(unsafe { Link::new_unchecked(this)}),
pin: PhantomPinned,
}? Infallible)
}

#[inline]
pub fn insert_prev(list: &ListHead) -> impl PinInit<Self, Infallible> + '_ {
try_pin_init!(&this in Self {
next: list.prev.next().replace(unsafe { Link::new_unchecked(this)}),
prev: list.prev.replace(unsafe { Link::new_unchecked(this)}),
pin: PhantomPinned,
}? Infallible)
}

#[inline]
pub fn next(&self) -> Option<NonNull<Self>> {
if ptr::eq(self.next.as_ptr(), self) {
None
} else {
Some(unsafe { NonNull::new_unchecked(self.next.as_ptr() as *mut Self) })
}
}

#[allow(dead_code)]
pub fn size(&self) -> usize {
let mut size = 1;
let mut cur = self.next.clone();
while !ptr::eq(self, cur.cur()) {
cur = cur.next().clone();
size += 1;
}
size
}
}

#[pinned_drop]
impl PinnedDrop for ListHead {
//#[inline]
fn drop(self: Pin<&mut Self>) {
if !ptr::eq(self.next.as_ptr(), &*self) {
let next = unsafe { &*self.next.as_ptr() };
let prev = unsafe { &*self.prev.as_ptr() };
next.prev.set(&self.prev);
prev.next.set(&self.next);
}
}
}

#[repr(transparent)]
#[derive(Clone, Debug)]
struct Link(Cell<NonNull<ListHead>>);

impl Link {
#[inline]
unsafe fn new_unchecked(ptr: NonNull<ListHead>) -> Self {
Self(Cell::new(ptr))
}

#[inline]
fn next(&self) -> &Link {
unsafe { &(*self.0.get().as_ptr()).next }
}

#[inline]
fn prev(&self) -> &Link {
unsafe { &(*self.0.get().as_ptr()).prev }
}

#[allow(dead_code)]
fn cur(&self) -> &ListHead {
unsafe { &*self.0.get().as_ptr() }
}

#[inline]
fn replace(&self, other: Link) -> Link {
unsafe { Link::new_unchecked(self.0.replace(other.0.get())) }
}

#[inline]
fn as_ptr(&self) -> *const ListHead {
self.0.get().as_ptr()
}

#[inline]
fn set(&self, val: &Link) {
self.0.set(val.0.get());
}
}

#[allow(dead_code)]
#[cfg_attr(test, test)]
fn main() -> Result<(), AllocError> {
let a = Box::pin_init(ListHead::new())?;
stack_pin_init!(let b = ListHead::insert_next(&a));
stack_pin_init!(let c = ListHead::insert_next(&a));
stack_pin_init!(let d = ListHead::insert_next(&b));
let e = Box::pin_init(ListHead::insert_next(&b))?;
println!("a ({a:p}): {a:?}");
println!("b ({b:p}): {b:?}");
println!("c ({c:p}): {c:?}");
println!("d ({d:p}): {d:?}");
println!("e ({e:p}): {e:?}");
let mut inspect = &*a;
while let Some(next) = inspect.next() {
println!("({inspect:p}): {inspect:?}");
inspect = unsafe { &*next.as_ptr() };
if core::ptr::eq(inspect, &*a) {
break;
}
}
Ok(())
}
195 changes: 195 additions & 0 deletions rust/pinned_init/examples/mutex.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
#![feature(allocator_api)]
use core::{
cell::{Cell, UnsafeCell},
marker::PhantomPinned,
ops::{Deref, DerefMut},
pin::Pin,
sync::atomic::{AtomicBool, Ordering},
};
use std::{
sync::Arc,
thread::{self, park, sleep, Builder, Thread},
time::Duration,
};

use pinned_init::*;
#[allow(unused_attributes)]
#[path = "./linked_list.rs"]
pub mod linked_list;
use linked_list::*;

pub struct SpinLock {
inner: AtomicBool,
}

impl SpinLock {
#[inline]
pub fn acquire(&self) -> SpinLockGuard<'_> {
while self
.inner
.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
.is_err()
{
while self.inner.load(Ordering::Relaxed) {
thread::yield_now();
}
}
SpinLockGuard(self)
}

#[inline]
pub const fn new() -> Self {
Self {
inner: AtomicBool::new(false),
}
}
}

pub struct SpinLockGuard<'a>(&'a SpinLock);

impl Drop for SpinLockGuard<'_> {
#[inline]
fn drop(&mut self) {
self.0.inner.store(false, Ordering::Release);
}
}

#[pin_data]
pub struct CMutex<T> {
#[pin]
wait_list: ListHead,
spin_lock: SpinLock,
locked: Cell<bool>,
#[pin]
data: UnsafeCell<T>,
}

impl<T> CMutex<T> {
#[inline]
pub fn new(val: impl PinInit<T>) -> impl PinInit<Self> {
pin_init!(CMutex {
wait_list <- ListHead::new(),
spin_lock: SpinLock::new(),
locked: Cell::new(false),
data <- unsafe {
pin_init_from_closure(|slot: *mut UnsafeCell<T>| val.__pinned_init(slot.cast::<T>()))
},
})
}

#[inline]
pub fn lock(&self) -> Pin<CMutexGuard<'_, T>> {
let mut sguard = self.spin_lock.acquire();
if self.locked.get() {
stack_pin_init!(let wait_entry = WaitEntry::insert_new(&self.wait_list));
// println!("wait list length: {}", self.wait_list.size());
while self.locked.get() {
drop(sguard);
park();
sguard = self.spin_lock.acquire();
}
drop(wait_entry);
}
self.locked.set(true);
unsafe {
Pin::new_unchecked(CMutexGuard {
mtx: self,
_pin: PhantomPinned,
})
}
}

pub fn get_data_mut(self: Pin<&mut Self>) -> &mut T {
// SAFETY: we have an exclusive reference and thus nobody has access to data.
unsafe { &mut *self.data.get() }
}
}

unsafe impl<T: Send> Send for CMutex<T> {}
unsafe impl<T: Send> Sync for CMutex<T> {}

pub struct CMutexGuard<'a, T> {
mtx: &'a CMutex<T>,
_pin: PhantomPinned,
}

impl<'a, T> Drop for CMutexGuard<'a, T> {
#[inline]
fn drop(&mut self) {
let sguard = self.mtx.spin_lock.acquire();
self.mtx.locked.set(false);
if let Some(list_field) = self.mtx.wait_list.next() {
let wait_entry = list_field.as_ptr().cast::<WaitEntry>();
unsafe { (*wait_entry).thread.unpark() };
}
drop(sguard);
}
}

impl<'a, T> Deref for CMutexGuard<'a, T> {
type Target = T;

#[inline]
fn deref(&self) -> &Self::Target {
unsafe { &*self.mtx.data.get() }
}
}

impl<'a, T> DerefMut for CMutexGuard<'a, T> {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
unsafe { &mut *self.mtx.data.get() }
}
}

#[pin_data]
#[repr(C)]
struct WaitEntry {
#[pin]
wait_list: ListHead,
thread: Thread,
}

impl WaitEntry {
#[inline]
fn insert_new(list: &ListHead) -> impl PinInit<Self> + '_ {
pin_init!(Self {
thread: thread::current(),
wait_list <- ListHead::insert_prev(list),
})
}
}

#[allow(dead_code)]
#[cfg_attr(test, test)]
fn main() {
let mtx: Pin<Arc<CMutex<usize>>> =
<Arc<_> as InPlaceInit<_>>::pin_init(CMutex::new(0)).unwrap();
let mut handles = vec![];
let thread_count = 20;
let workload = if cfg!(miri) { 100 } else { 1_000_000 };
for i in 0..thread_count {
let mtx = mtx.clone();
handles.push(
Builder::new()
.name(format!("worker #{i}"))
.spawn(move || {
for _ in 0..workload {
*mtx.lock() += 1;
}
println!("{i} halfway");
sleep(Duration::from_millis((i as u64) * 10));
for _ in 0..workload {
*mtx.lock() += 1;
}
println!("{i} finished");
})
.expect("should not fail"),
);
}
for h in handles {
h.join().expect("thread paniced");
}
println!("{:?}", &*mtx.lock());
assert_eq!(*mtx.lock(), workload * thread_count * 2);
}
Loading

0 comments on commit 2900f80

Please sign in to comment.