Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure TLS destructors run before thread joins in SGX #84409

Merged
merged 3 commits into from
May 7, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions library/std/src/sys/sgx/abi/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,12 @@ unsafe extern "C" fn tcs_init(secondary: bool) {
extern "C" fn entry(p1: u64, p2: u64, p3: u64, secondary: bool, p4: u64, p5: u64) -> EntryReturn {
// FIXME: how to support TLS in library mode?
let tls = Box::new(tls::Tls::new());
let _tls_guard = unsafe { tls.activate() };
let tls_guard = unsafe { tls.activate() };

if secondary {
super::thread::Thread::entry();
let join_notifier = super::thread::Thread::entry();
drop(tls_guard);
drop(join_notifier);

EntryReturn(0, 0)
} else {
Expand Down
69 changes: 61 additions & 8 deletions library/std/src/sys/sgx/thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,37 @@ pub struct Thread(task_queue::JoinHandle);

pub const DEFAULT_MIN_STACK_SIZE: usize = 4096;

pub use self::task_queue::JoinNotifier;

mod task_queue {
use crate::sync::mpsc;
use super::wait_notify;
use crate::sync::{Mutex, MutexGuard, Once};

pub type JoinHandle = mpsc::Receiver<()>;
pub type JoinHandle = wait_notify::Waiter;

pub struct JoinNotifier(Option<wait_notify::Notifier>);

impl Drop for JoinNotifier {
fn drop(&mut self) {
self.0.take().unwrap().notify();
}
}

pub(super) struct Task {
p: Box<dyn FnOnce()>,
done: mpsc::Sender<()>,
done: JoinNotifier,
}

impl Task {
pub(super) fn new(p: Box<dyn FnOnce()>) -> (Task, JoinHandle) {
let (done, recv) = mpsc::channel();
let (done, recv) = wait_notify::new();
let done = JoinNotifier(Some(done));
(Task { p, done }, recv)
}

pub(super) fn run(self) {
pub(super) fn run(self) -> JoinNotifier {
(self.p)();
let _ = self.done.send(());
self.done
}
}

Expand All @@ -47,6 +58,48 @@ mod task_queue {
}
}

/// This module provides a synchronization primitive that does not use thread
/// local variables. This is needed for signaling that a thread has finished
/// execution. The signal is sent once all TLS destructors have finished at
/// which point no new thread locals should be created.
pub mod wait_notify {
use super::super::waitqueue::{SpinMutex, WaitQueue, WaitVariable};
use crate::sync::Arc;

pub struct Notifier(Arc<SpinMutex<WaitVariable<bool>>>);

impl Notifier {
/// Notify the waiter. The waiter is either notified right away (if
/// currently blocked in `Waiter::wait()`) or later when it calls the
/// `Waiter::wait()` method.
pub fn notify(self) {
let mut guard = self.0.lock();
*guard.lock_var_mut() = true;
let _ = WaitQueue::notify_one(guard);
}
}

pub struct Waiter(Arc<SpinMutex<WaitVariable<bool>>>);

impl Waiter {
/// Wait for a notification. If `Notifier::notify()` has already been
/// called, this will return immediately, otherwise the current thread
/// is blocked until notified.
pub fn wait(self) {
let guard = self.0.lock();
if *guard.lock_var() {
return;
}
WaitQueue::wait(guard, || {});
}
}

pub fn new() -> (Notifier, Waiter) {
let inner = Arc::new(SpinMutex::new(WaitVariable::new(false)));
(Notifier(inner.clone()), Waiter(inner))
}
}

impl Thread {
// unsafe: see thread::Builder::spawn_unchecked for safety requirements
pub unsafe fn new(_stack: usize, p: Box<dyn FnOnce()>) -> io::Result<Thread> {
Expand All @@ -57,7 +110,7 @@ impl Thread {
Ok(Thread(handle))
}

pub(super) fn entry() {
pub(super) fn entry() -> JoinNotifier {
let mut pending_tasks = task_queue::lock();
let task = rtunwrap!(Some, pending_tasks.pop());
drop(pending_tasks); // make sure to not hold the task queue lock longer than necessary
Expand All @@ -78,7 +131,7 @@ impl Thread {
}

pub fn join(self) {
let _ = self.0.recv();
self.0.wait();
}
}

Expand Down
55 changes: 54 additions & 1 deletion library/std/src/thread/local/tests.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::cell::{Cell, UnsafeCell};
use crate::sync::mpsc::{channel, Sender};
use crate::sync::atomic::{AtomicBool, Ordering};
use crate::sync::mpsc::{self, channel, Sender};
use crate::thread::{self, LocalKey};
use crate::thread_local;

Expand Down Expand Up @@ -207,3 +208,55 @@ fn dtors_in_dtors_in_dtors_const_init() {
});
rx.recv().unwrap();
}

// This test tests that TLS destructors have run before the thread joins. The
// test has no false positives (meaning: if the test fails, there's actually
// an ordering problem). It may have false negatives, where the test passes but
// join is not guaranteed to be after the TLS destructors. However, false
// negatives should be exceedingly rare due to judicious use of
// thread::yield_now and running the test several times.
#[test]
fn join_orders_after_tls_destructors() {
static THREAD2_LAUNCHED: AtomicBool = AtomicBool::new(false);

for _ in 0..10 {
let (tx, rx) = mpsc::sync_channel(0);
THREAD2_LAUNCHED.store(false, Ordering::SeqCst);

let jh = thread::spawn(move || {
struct RecvOnDrop(Cell<Option<mpsc::Receiver<()>>>);

impl Drop for RecvOnDrop {
fn drop(&mut self) {
let rx = self.0.take().unwrap();
while !THREAD2_LAUNCHED.load(Ordering::SeqCst) {
thread::yield_now();
}
rx.recv().unwrap();
}
}

thread_local! {
static TL_RX: RecvOnDrop = RecvOnDrop(Cell::new(None));
}

TL_RX.with(|v| v.0.set(Some(rx)))
});

let tx_clone = tx.clone();
let jh2 = thread::spawn(move || {
THREAD2_LAUNCHED.store(true, Ordering::SeqCst);
jh.join().unwrap();
tx_clone.send(()).expect_err(
"Expecting channel to be closed because thread 1 TLS destructors must've run",
);
});

while !THREAD2_LAUNCHED.load(Ordering::SeqCst) {
thread::yield_now();
}
thread::yield_now();
tx.send(()).expect("Expecting channel to be live because thread 2 must block on join");
jh2.join().unwrap();
}
}