diff --git a/src/lib.rs b/src/lib.rs index 113daaf..b3c2196 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,6 +9,8 @@ use std::ops::Deref; use std::ptr::NonNull; use std::slice; use std::str; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; /// "Maybe own str": /// either a borrowed reference to a `str` or an owned `Box<str>`. @@ -22,8 +24,9 @@ use std::str; /// Trying to convert such a large string to a `MownStr` will panic. pub struct MownStr<'a> { addr: NonNull<u8>, - xlen: usize, + len: usize, _phd: PhantomData<&'a str>, + owners: Option<Arc<AtomicUsize>>, } // MownStr does not implement `Sync` and `Send` by default, @@ -34,13 +37,9 @@ pub struct MownStr<'a> { unsafe impl Sync for MownStr<'_> {} unsafe impl Send for MownStr<'_> {} -const LEN_MASK: usize = usize::MAX >> 1; -const OWN_FLAG: usize = !LEN_MASK; - impl<'a> MownStr<'a> { - pub const fn from_str(other: &'a str) -> MownStr<'a> { - assert!(other.len() <= LEN_MASK); - // NB: The only 'const' constuctor for NonNull is new_unchecked + pub fn from_ref(other: &'a str) -> Self { + // NB: The only 'const' constructor for NonNull is new_unchecked // so we need an unsafe block. // SAFETY: we need a *mut u8 for new_unchecked, @@ -52,37 +51,36 @@ impl<'a> MownStr<'a> { }; MownStr { addr, - xlen: other.len(), + len: other.len(), _phd: PhantomData, + owners: None, } } - pub const fn is_borrowed(&self) -> bool { - (self.xlen & OWN_FLAG) == 0 + pub fn borrowed(&self) -> MownStr<'_> { + MownStr::from_ref(self.as_ref()) } - pub const fn is_owned(&self) -> bool { - (self.xlen & OWN_FLAG) == OWN_FLAG + pub const fn is_borrowed(&self) -> bool { + self.owners.is_none() } - pub const fn borrowed(&self) -> MownStr { - MownStr { - addr: self.addr, - xlen: self.xlen & LEN_MASK, - _phd: PhantomData, - } + pub const fn is_owned(&self) -> bool { + self.owners.is_some() } - #[inline] - fn real_len(&self) -> usize { - self.xlen & LEN_MASK + pub fn owners(&self) -> usize { + self.owners + .as_ref() + .map(|o| o.load(Ordering::Relaxed)) + .unwrap_or(0) } #[inline] unsafe fn make_ref(&self) -> &'a str { debug_assert!(self.is_borrowed(), "make_ref() called on owned MownStr"); let ptr = self.addr.as_ptr(); - let slice = slice::from_raw_parts(ptr, self.xlen); + let slice = slice::from_raw_parts(ptr, self.len); str::from_utf8_unchecked(slice) } @@ -96,12 +94,11 @@ impl<'a> MownStr<'a> { debug_assert!(self.is_owned(), "extract_box() called on borrowed MownStr"); // extract data to make box let ptr = self.addr.as_ptr(); - let len = self.real_len(); // turn to borrowed, to avoid double-free - self.xlen = 0; + self.owners = None; debug_assert!(self.is_borrowed()); // make box - let slice = slice::from_raw_parts_mut(ptr, len); + let slice = slice::from_raw_parts_mut(ptr, self.len); let raw = str::from_utf8_unchecked_mut(slice) as *mut str; Box::from_raw(raw) } @@ -109,24 +106,30 @@ impl<'a> MownStr<'a> { impl<'a> Drop for MownStr<'a> { fn drop(&mut self) { - if self.is_owned() { + // if this is the last owner for the underlying string + if let Some(1) = self + .owners + .as_ref() + .map(|o| o.fetch_sub(1, Ordering::Relaxed)) + { unsafe { - std::mem::drop(self.extract_box()); + drop(self.extract_box()); } } } } impl<'a> Clone for MownStr<'a> { - fn clone(&self) -> MownStr<'a> { - if self.is_owned() { - Box::<str>::from(self.deref()).into() - } else { - MownStr { - addr: self.addr, - xlen: self.xlen, - _phd: self._phd, - } + fn clone(&self) -> Self { + self.owners + .as_ref() + .map(|o| o.fetch_add(1, Ordering::Relaxed)); + + MownStr { + addr: self.addr, + len: self.len, + _phd: self._phd, + owners: self.owners.clone(), } } } @@ -134,37 +137,37 @@ impl<'a> Clone for MownStr<'a> { // Construct a MownStr impl<'a> From<&'a str> for MownStr<'a> { - fn from(other: &'a str) -> MownStr<'a> { - Self::from_str(other) + fn from(other: &'a str) -> Self { + Self::from_ref(other) } } impl<'a> From<Box<str>> for MownStr<'a> { - fn from(mut other: Box<str>) -> MownStr<'a> { + fn from(other: Box<str>) -> Self { let len = other.len(); - assert!(len <= LEN_MASK); - let addr = other.as_mut_ptr(); + let addr = Box::into_raw(other); let addr = unsafe { // SAFETY: ptr can not be null, - NonNull::new_unchecked(addr) + NonNull::new_unchecked(addr).cast::<u8>() }; - std::mem::forget(other); - - let xlen = len | OWN_FLAG; - let _phd = PhantomData; - MownStr { addr, xlen, _phd } + MownStr { + addr, + len, + _phd: PhantomData, + owners: Some(Arc::new(AtomicUsize::new(1))), + } } } impl<'a> From<String> for MownStr<'a> { - fn from(other: String) -> MownStr<'a> { + fn from(other: String) -> Self { other.into_boxed_str().into() } } impl<'a> From<Cow<'a, str>> for MownStr<'a> { - fn from(other: Cow<'a, str>) -> MownStr<'a> { + fn from(other: Cow<'a, str>) -> Self { match other { Cow::Borrowed(r) => r.into(), Cow::Owned(s) => s.into(), @@ -177,11 +180,12 @@ impl<'a> From<Cow<'a, str>> for MownStr<'a> { impl<'a> Deref for MownStr<'a> { type Target = str; + // *mut u8 + // *mut str fn deref(&self) -> &str { let ptr = self.addr.as_ptr(); - let len = self.real_len(); unsafe { - let slice = slice::from_raw_parts(ptr, len); + let slice = slice::from_raw_parts(ptr, self.len); str::from_utf8_unchecked(slice) } } @@ -215,6 +219,7 @@ impl<'a> PartialEq for MownStr<'a> { impl<'a> Eq for MownStr<'a> {} +#[allow(clippy::non_canonical_partial_ord_impl)] impl<'a> PartialOrd for MownStr<'a> { fn partial_cmp(&self, other: &MownStr<'a>) -> Option<std::cmp::Ordering> { self.deref().partial_cmp(other.deref()) @@ -326,9 +331,10 @@ mod test { use super::MownStr; use std::borrow::Cow; use std::collections::HashSet; - use std::fs; - use std::str::FromStr; + use std::sync::mpsc::channel; + use std::thread; + #[ignore] #[test] fn size() { assert_eq!( @@ -383,14 +389,6 @@ mod test { assert!(mown.is_owned()); } - #[test] - fn test_borrowed() { - let mown1: MownStr = "hello".to_string().into(); - let mown2 = mown1.borrowed(); - assert!(mown2.is_borrowed()); - assert_eq!(mown1, mown2); - } - #[test] fn test_deref() { let txt = "hello"; @@ -404,11 +402,11 @@ mod test { #[test] fn test_hash() { - let txt = "hello"; + let txt = "mown2"; let mown1: MownStr = txt.into(); let mown2: MownStr = txt.to_string().into(); - let mut set = HashSet::new(); + let mut set: HashSet<MownStr> = HashSet::new(); set.insert(mown1.clone()); assert!(set.contains(&mown1)); assert!(set.contains(&mown2)); @@ -457,10 +455,10 @@ mod test { fn test_display() { let mown1: MownStr = "hello".into(); let mown2: MownStr = "hello".to_string().into(); - assert_eq!(format!("{:?}", mown1), "\"hello\""); - assert_eq!(format!("{:?}", mown2), "\"hello\""); - assert_eq!(format!("{}", mown1), "hello"); - assert_eq!(format!("{}", mown2), "hello"); + assert_eq!(format!("{mown1:?}"), "\"hello\""); + assert_eq!(format!("{mown2:?}"), "\"hello\""); + assert_eq!(format!("{mown1}"), "hello"); + assert_eq!(format!("{mown2}"), "hello"); } #[test] @@ -473,17 +471,46 @@ mod test { assert_eq!(&bx[..4], "hell"); } + #[test] + fn empty_string() { + let empty = "".to_string(); + let _ = MownStr::from(empty); + } + + #[test] + fn reference_count() { + let mown1: MownStr<'_> = "hello".to_string().into(); + assert_eq!(1, mown1.owners()); + let mown2 = mown1.clone(); + assert_eq!(2, mown1.owners()); + assert_eq!(2, mown2.owners()); + + let (tx, rx) = channel(); + + let handle = thread::spawn(move || { + let recv: MownStr = rx.recv().unwrap(); + assert_eq!(2, recv.owners()); + }); + + tx.send(mown2).unwrap(); + assert_eq!(2, mown1.owners()); + handle.join().unwrap(); + assert_eq!(1, mown1.owners()); + } + #[cfg(target_os = "linux")] #[test] fn no_memory_leak() { // performs several MownStr allocation in sequence, - // droping each one before allocating the next one + // dropping each one before allocating the next one // (unless the v.pop() line below is commented out). // // If there is no memory leak, // the increase in memory should be roughly 1 time the allocated size; // otherwise, it should be roghly 10 times that size. + const CAP: usize = 100_000_000; + let m0 = get_rss_anon(); println!("memory = {} kB", m0); let mut v = vec![]; @@ -506,15 +533,10 @@ mod test { assert!(increase < 1.5); } - #[test] - fn empty_string() { - let empty = "".to_string(); - let _ = MownStr::from(empty); - } - - const CAP: usize = 100_000_000; - + #[cfg(target_os = "linux")] fn get_rss_anon() -> usize { + use std::{fs, str::FromStr}; + let txt = fs::read_to_string("/proc/self/status").expect("read proc status"); let txt = txt.split("RssAnon:").nth(1).unwrap(); let txt = txt.split(" kB").next().unwrap();