Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix free after clone Pt. 2 #7

Closed
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 85 additions & 65 deletions src/lib.rs
mkatychev marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use std::ops::Deref;
use std::ptr::NonNull;
use std::slice;
use std::str;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

/// "Maybe own str":
/// either a borrowed reference to a `str` or an owned `Box<str>`.
Expand All @@ -22,8 +24,9 @@ use std::str;
/// Trying to convert such a large string to a `MownStr` will panic.
pub struct MownStr<'a> {
addr: NonNull<u8>,
xlen: usize,
len: usize,
_phd: PhantomData<&'a str>,
owners: Option<Arc<AtomicUsize>>,
}

// MownStr does not implement `Sync` and `Send` by default,
Expand All @@ -34,12 +37,8 @@ pub struct MownStr<'a> {
unsafe impl Sync for MownStr<'_> {}
unsafe impl Send for MownStr<'_> {}

const LEN_MASK: usize = usize::MAX >> 1;
const OWN_FLAG: usize = !LEN_MASK;

impl<'a> MownStr<'a> {
pub const fn from_str(other: &'a str) -> MownStr<'a> {
assert!(other.len() <= LEN_MASK);
pub fn from_str(other: &'a str) -> MownStr<'a> {
// NB: The only 'const' constuctor for NonNull is new_unchecked
// so we need an unsafe block.

Expand All @@ -52,37 +51,38 @@ impl<'a> MownStr<'a> {
};
MownStr {
addr,
xlen: other.len(),
len: other.len(),
_phd: PhantomData,
owners: None,
Comment on lines +54 to +56
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change goes against the aim of MownStr, which is to cost no more space than a regular &str or Box<str>. If we go for this cost, we might as well use Cow<str> instead of MownStr.

}
}
fn cleanup(&self) -> bool {
self.owners
.as_ref()
.map(|c| c.load(Ordering::SeqCst) == 0)
.unwrap_or(false)
}

pub const fn is_borrowed(&self) -> bool {
(self.xlen & OWN_FLAG) == 0
self.owners.is_none()
}

pub const fn is_owned(&self) -> bool {
(self.xlen & OWN_FLAG) == OWN_FLAG
self.owners.is_some()
}

pub const fn borrowed(&self) -> MownStr {
MownStr {
addr: self.addr,
xlen: self.xlen & LEN_MASK,
_phd: PhantomData,
}
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pchampin this method was particularly finicky in triggering use after free UB behaviour in a lot of sophia term methods

I strongly recommend removing it altogether and rely on Clone + Arc semantics to conserve on string allocation


#[inline]
fn real_len(&self) -> usize {
self.xlen & LEN_MASK
pub fn owners(&self) -> usize {
self.owners
.as_ref()
.map(|o| o.load(Ordering::Relaxed))
.unwrap_or(0)
}

#[inline]
unsafe fn make_ref(&self) -> &'a str {
debug_assert!(self.is_borrowed(), "make_ref() called on owned MownStr");
let ptr = self.addr.as_ptr();
let slice = slice::from_raw_parts(ptr, self.xlen);
let slice = slice::from_raw_parts(ptr, self.len);
str::from_utf8_unchecked(slice)
}

Expand All @@ -96,37 +96,41 @@ impl<'a> MownStr<'a> {
debug_assert!(self.is_owned(), "extract_box() called on borrowed MownStr");
// extract data to make box
let ptr = self.addr.as_ptr();
let len = self.real_len();
// turn to borrowed, to avoid double-free
self.xlen = 0;
self.owners = None;
debug_assert!(self.is_borrowed());
// make box
let slice = slice::from_raw_parts_mut(ptr, len);
let slice = slice::from_raw_parts_mut(ptr, self.len);
let raw = str::from_utf8_unchecked_mut(slice) as *mut str;
Box::from_raw(raw)
}
}

impl<'a> Drop for MownStr<'a> {
fn drop(&mut self) {
if self.is_owned() {
self.owners
.as_ref()
.map(|o| o.fetch_sub(1, Ordering::SeqCst));
if self.cleanup() {
dbg!(&self.owners);
unsafe {
std::mem::drop(self.extract_box());
drop(self.extract_box());
}
}
}
}

impl<'a> Clone for MownStr<'a> {
fn clone(&self) -> MownStr<'a> {
if self.is_owned() {
Box::<str>::from(self.deref()).into()
} else {
MownStr {
addr: self.addr,
xlen: self.xlen,
_phd: self._phd,
}
self.owners
.as_ref()
.map(|o| o.fetch_add(1, Ordering::SeqCst));

MownStr {
addr: self.addr,
len: self.len,
_phd: self._phd,
owners: self.owners.clone(),
}
}
}
Expand All @@ -140,20 +144,20 @@ impl<'a> From<&'a str> for MownStr<'a> {
}

impl<'a> From<Box<str>> for MownStr<'a> {
fn from(mut other: Box<str>) -> MownStr<'a> {
fn from(other: Box<str>) -> MownStr<'a> {
let len = other.len();
assert!(len <= LEN_MASK);
let addr = other.as_mut_ptr();
let addr = Box::into_raw(other);
let addr = unsafe {
// SAFETY: ptr can not be null,
NonNull::new_unchecked(addr)
NonNull::new_unchecked(addr).cast::<u8>()
};

std::mem::forget(other);

let xlen = len | OWN_FLAG;
let _phd = PhantomData;
MownStr { addr, xlen, _phd }
MownStr {
addr,
len,
_phd: PhantomData,
owners: Some(Arc::new(AtomicUsize::new(1))),
}
}
}

Expand All @@ -179,9 +183,8 @@ impl<'a> Deref for MownStr<'a> {

fn deref(&self) -> &str {
let ptr = self.addr.as_ptr();
let len = self.real_len();
unsafe {
let slice = slice::from_raw_parts(ptr, len);
let slice = slice::from_raw_parts(ptr, self.len);
str::from_utf8_unchecked(slice)
}
}
Expand Down Expand Up @@ -326,9 +329,10 @@ mod test {
use super::MownStr;
use std::borrow::Cow;
use std::collections::HashSet;
use std::fs;
use std::str::FromStr;
use std::sync::mpsc::channel;
use std::thread;

#[ignore]
#[test]
fn size() {
assert_eq!(
Expand Down Expand Up @@ -383,14 +387,6 @@ mod test {
assert!(mown.is_owned());
}

#[test]
fn test_borrowed() {
let mown1: MownStr = "hello".to_string().into();
let mown2 = mown1.borrowed();
assert!(mown2.is_borrowed());
assert_eq!(mown1, mown2);
}

#[test]
fn test_deref() {
let txt = "hello";
Expand All @@ -404,11 +400,11 @@ mod test {

#[test]
fn test_hash() {
let txt = "hello";
let txt = "mown2";
let mown1: MownStr = txt.into();
let mown2: MownStr = txt.to_string().into();

let mut set = HashSet::new();
let mut set: HashSet<MownStr> = HashSet::new();
set.insert(mown1.clone());
assert!(set.contains(&mown1));
assert!(set.contains(&mown2));
Expand Down Expand Up @@ -473,6 +469,33 @@ mod test {
assert_eq!(&bx[..4], "hell");
}

#[test]
fn empty_string() {
let empty = "".to_string();
let _ = MownStr::from(empty);
}

#[test]
fn reference_count() {
let mown1: MownStr<'_> = "hello".to_string().into();
assert_eq!(1, mown1.owners());
let mown2 = mown1.clone();
assert_eq!(2, mown1.owners());
assert_eq!(2, mown2.owners());

let (tx, rx) = channel();

let handle = thread::spawn(move || {
let recv: MownStr = rx.recv().unwrap();
assert_eq!(2, recv.owners());
});

tx.send(mown2).unwrap();
assert_eq!(2, mown1.owners());
handle.join().unwrap();
assert_eq!(1, mown1.owners());
}

#[cfg(target_os = "linux")]
#[test]
fn no_memory_leak() {
Expand All @@ -484,6 +507,8 @@ mod test {
// the increase in memory should be roughly 1 time the allocated size;
// otherwise, it should be roghly 10 times that size.

const CAP: usize = 100_000_000;

let m0 = get_rss_anon();
println!("memory = {} kB", m0);
let mut v = vec![];
Expand All @@ -506,15 +531,10 @@ mod test {
assert!(increase < 1.5);
}

#[test]
fn empty_string() {
let empty = "".to_string();
let _ = MownStr::from(empty);
}

const CAP: usize = 100_000_000;

#[cfg(target_os = "linux")]
fn get_rss_anon() -> usize {
use std::{fs, str::FromStr};

let txt = fs::read_to_string("/proc/self/status").expect("read proc status");
let txt = txt.split("RssAnon:").nth(1).unwrap();
let txt = txt.split(" kB").next().unwrap();
Expand Down