diff --git a/rayon-core/src/lib.rs b/rayon-core/src/lib.rs index 186f1a08d..c683eb6f9 100644 --- a/rayon-core/src/lib.rs +++ b/rayon-core/src/lib.rs @@ -27,7 +27,9 @@ use std::error::Error; use std::fmt; use std::io; use std::marker::PhantomData; +use std::mem; use std::str::FromStr; +use std::time::Duration; #[macro_use] mod log; @@ -129,6 +131,9 @@ pub struct ThreadPoolBuilder { /// Closure to compute the name of a thread. get_thread_name: Option String>>, + /// The wait time before bringing up a thread. + get_thread_wait_time: Option Option>>, + /// The stack size for the created worker threads stack_size: Option, @@ -188,6 +193,7 @@ impl Default for ThreadPoolBuilder { num_threads: 0, panic_handler: None, get_thread_name: None, + get_thread_wait_time: None, stack_size: None, start_handler: None, exit_handler: None, @@ -219,7 +225,7 @@ impl ThreadPoolBuilder { /// default spawn and those set by [`spawn_handler`](#method.spawn_handler). impl ThreadPoolBuilder where - S: ThreadSpawn, + S: ThreadSpawn + Send + 'static, { /// Create a new `ThreadPool` initialized using this configuration. pub fn build(self) -> Result { @@ -295,19 +301,29 @@ impl ThreadPoolBuilder { { let result = crossbeam_utils::thread::scope(|scope| { let wrapper = &wrapper; - let pool = self - .spawn_handler(|thread| { - let mut builder = scope.builder(); - if let Some(name) = thread.name() { - builder = builder.name(name.to_string()); - } - if let Some(size) = thread.stack_size() { - builder = builder.stack_size(size); - } - builder.spawn(move |_| wrapper(thread))?; - Ok(()) - }) - .build()?; + let spawn_handler = move |thread: ThreadBuilder| { + let mut builder = scope.builder(); + if let Some(name) = thread.name() { + builder = builder.name(name.to_string()); + } + if let Some(size) = thread.stack_size() { + builder = builder.stack_size(size); + } + builder.spawn(move |_| wrapper(thread))?; + Ok(()) + }; + + // Allocate `spawn_handler` on the heap and make it `'static`. + // This is safe because we wait for the thread pool to end before + // returning from this closure. Once the thread pool has ended + // no more uses of `spawn_handler` can occur. + let spawn_handler: Box io::Result<()> + Sync + Send + '_> = + Box::new(spawn_handler); + let spawn_handler: Box< + dyn Fn(ThreadBuilder) -> io::Result<()> + Sync + Send + 'static, + > = unsafe { mem::transmute(spawn_handler) }; + + let pool = self.spawn_handler(spawn_handler).build()?; let result = unwind::halt_unwinding(|| with_pool(&pool)); pool.wait_until_stopped(); match result { @@ -388,6 +404,7 @@ impl ThreadPoolBuilder { num_threads: self.num_threads, panic_handler: self.panic_handler, get_thread_name: self.get_thread_name, + get_thread_wait_time: self.get_thread_wait_time, stack_size: self.stack_size, start_handler: self.start_handler, exit_handler: self.exit_handler, @@ -445,6 +462,22 @@ impl ThreadPoolBuilder { self } + /// Get the thread wait time for the thread with the given index. + fn get_thread_wait_time(&mut self, index: usize) -> Option { + let f = self.get_thread_wait_time.as_mut()?; + f(index) + } + + /// Set a closure which takes a thread index and returns + /// the thread's wait time. + pub fn thread_wait_time(mut self, closure: F) -> Self + where + F: FnMut(usize) -> Option + 'static, + { + self.get_thread_wait_time = Some(Box::new(closure)); + self + } + /// Set the number of threads to be used in the rayon threadpool. /// /// If you specify a non-zero number of threads using this @@ -745,6 +778,7 @@ impl fmt::Debug for ThreadPoolBuilder { let ThreadPoolBuilder { ref num_threads, ref get_thread_name, + ref get_thread_wait_time, ref panic_handler, ref stack_size, ref deadlock_handler, @@ -765,6 +799,7 @@ impl fmt::Debug for ThreadPoolBuilder { } } let get_thread_name = get_thread_name.as_ref().map(|_| ClosurePlaceholder); + let get_thread_wait_time = get_thread_wait_time.as_ref().map(|_| ClosurePlaceholder); let panic_handler = panic_handler.as_ref().map(|_| ClosurePlaceholder); let deadlock_handler = deadlock_handler.as_ref().map(|_| ClosurePlaceholder); let start_handler = start_handler.as_ref().map(|_| ClosurePlaceholder); @@ -775,6 +810,7 @@ impl fmt::Debug for ThreadPoolBuilder { f.debug_struct("ThreadPoolBuilder") .field("num_threads", num_threads) .field("get_thread_name", &get_thread_name) + .field("get_thread_wait_time", &get_thread_wait_time) .field("panic_handler", &panic_handler) .field("stack_size", &stack_size) .field("deadlock_handler", &deadlock_handler) diff --git a/rayon-core/src/registry.rs b/rayon-core/src/registry.rs index b63a79ff1..a7fa949b6 100644 --- a/rayon-core/src/registry.rs +++ b/rayon-core/src/registry.rs @@ -21,8 +21,9 @@ use std::ptr; #[allow(deprecated)] use std::sync::atomic::ATOMIC_USIZE_INIT; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::{Arc, Once}; +use std::sync::{Arc, Condvar, Mutex, MutexGuard, Once}; use std::thread; +use std::time::Duration; use std::usize; /// Thread builder used for customization via @@ -132,10 +133,33 @@ where } } +struct ThreadSpawningState { + /// The number of threads that we have started spawning. + spawned_threads: usize, + + /// Indicates if a thread being spawned right now. + spawning_thread: bool, + + /// The number of threads that the thread pool should have at this moment. + /// If this is lower than the spawned thread count, the thread pool will spawn + /// more threads until the target is hit. If this is set lower than the spawned thread + /// count, nothing will happen. + thread_target: usize, + + /// When this is set to true, no more threads are allowed to join the thread pool + /// and the `stopped` latch in `ThreadInfo` of the threads that haven't joined is set. + terminating: bool, +} + pub struct Registry { thread_infos: Vec, sleep: Sleep, injected_jobs: SegQueue, + thread_spawn: Mutex>, + + /// The stack size for the created worker threads + stack_size: Option, + panic_handler: Option>, pub(crate) deadlock_handler: Option>, start_handler: Option>, @@ -157,6 +181,17 @@ pub struct Registry { // These are always owned by some other job (e.g., one injected by `ThreadPool::install()`) // and that job will keep the pool alive. terminate_latch: CountLatch, + + // Used to avoid races when adding threads to the thread pool. + thread_spawning_state: Mutex, + + // The `thread_spawning_state` lock also protects this, but is stored outside so it can + // be efficiently accessed by work stealing, which does not need an up to date value. + active_threads: AtomicUsize, + + // Used with `thread_spawning_state` to wait for a duration or thread pool termination + // during thread startup. + terminate_cond_var: Condvar, } /// //////////////////////////////////////////////////////////////////////// @@ -180,7 +215,7 @@ pub(super) fn init_global_registry( builder: ThreadPoolBuilder, ) -> Result<&'static Arc, ThreadPoolBuildError> where - S: ThreadSpawn, + S: ThreadSpawn + Send + 'static, { set_global_registry(|| Registry::new(builder)) } @@ -219,27 +254,36 @@ impl Registry { mut builder: ThreadPoolBuilder, ) -> Result, ThreadPoolBuildError> where - S: ThreadSpawn, + S: ThreadSpawn + Send + 'static, { let n_threads = builder.get_num_threads(); let breadth_first = builder.get_breadth_first(); - let (workers, stealers): (Vec<_>, Vec<_>) = (0..n_threads) - .map(|_| { - let worker = if breadth_first { - Worker::new_fifo() - } else { - Worker::new_lifo() - }; + let queues = (0..n_threads).map(|_| { + let worker = if breadth_first { + Worker::new_fifo() + } else { + Worker::new_lifo() + }; - let stealer = worker.stealer(); - (worker, stealer) - }) - .unzip(); + let stealer = worker.stealer(); + (worker, stealer) + }); let registry = Arc::new(Registry { - thread_infos: stealers.into_iter().map(ThreadInfo::new).collect(), - sleep: Sleep::new(n_threads), + thread_infos: queues + .enumerate() + .map(|(index, (worker, stealer))| { + ThreadInfo::new( + stealer, + worker, + builder.get_thread_name(index), + builder.get_thread_wait_time(index), + ) + }) + .collect(), + sleep: Sleep::new(), + stack_size: builder.stack_size, injected_jobs: SegQueue::new(), terminate_latch: CountLatch::new(), panic_handler: builder.take_panic_handler(), @@ -248,23 +292,22 @@ impl Registry { exit_handler: builder.take_exit_handler(), acquire_thread_handler: builder.take_acquire_thread_handler(), release_thread_handler: builder.take_release_thread_handler(), + thread_spawn: Mutex::new(Box::new(builder.spawn_handler)), + thread_spawning_state: Mutex::new(ThreadSpawningState { + spawned_threads: 0, + spawning_thread: false, + thread_target: 1, + terminating: false, + }), + active_threads: AtomicUsize::new(0), + terminate_cond_var: Condvar::new(), }); // If we return early or panic, make sure to terminate existing threads. let t1000 = Terminator(®istry); - for (index, worker) in workers.into_iter().enumerate() { - let thread = ThreadBuilder { - name: builder.get_thread_name(index), - stack_size: builder.get_stack_size(), - registry: registry.clone(), - worker, - index, - }; - if let Err(e) = builder.get_spawn_handler().spawn(thread) { - return Err(ThreadPoolBuildError::new(ErrorKind::IOError(e))); - } - } + // Spawn the initial thread + registry.try_spawn_thread(registry.thread_spawning_state.lock().unwrap())?; // Returning normally now, without termination. mem::forget(t1000); @@ -272,6 +315,74 @@ impl Registry { Ok(registry.clone()) } + /// Changes the number of threads that the thread pool should have at this moment. + /// Threads will be spawned until this target is met. However if this target is lowered, + /// threads that are already spawned will remain a part of the thread pool. + pub fn set_thread_target(self: &Arc, target: usize) { + assert!(target > 0 && target <= self.num_threads()); + + let mut state = self.thread_spawning_state.lock().unwrap(); + state.thread_target = target; + + // Try to spawn a new thread if wanted + self.try_spawn_thread(state).unwrap(); + } + + /// Spawns a new thread if suitable conditions to add a thread exists. + fn try_spawn_thread( + self: &Arc, + mut state: MutexGuard<'_, ThreadSpawningState>, + ) -> Result<(), ThreadPoolBuildError> { + let index = { + if state.spawning_thread { + // Another thread is currently spawning already + return Ok(()); + } + + if state.spawned_threads >= state.thread_target { + // We are not supposed to spawn more threads right now + return Ok(()); + } + + if state.terminating { + // We are not supposed to spawn more threads after the thread pool is in a + // terminating state + return Ok(()); + } + + state.spawning_thread = true; + + let index = state.spawned_threads; + state.spawned_threads += 1; + + mem::drop(state); + + index + }; + + // Steal the work queue for this thread + let worker = self.thread_infos[index] + .worker + .lock() + .unwrap() + .take() + .unwrap(); + + let thread = ThreadBuilder { + name: self.thread_infos[index].name.clone(), + stack_size: self.stack_size, + registry: self.clone(), + worker, + index, + }; + + self.thread_spawn + .lock() + .unwrap() + .spawn(thread) + .map_err(|e| ThreadPoolBuildError::new(ErrorKind::IOError(e))) + } + pub fn current() -> Arc { unsafe { let worker_thread = WorkerThread::current(); @@ -530,7 +641,36 @@ impl Registry { /// extant work is completed. pub(super) fn terminate(&self) { self.terminate_latch.set(); - self.sleep.tickle(usize::MAX); + + if self.terminate_latch.probe() { + self.sleep.tickle(usize::MAX); + + let active_threads = { + let mut state = self.thread_spawning_state.lock().unwrap(); + + if state.terminating { + // Some other thread will mark the non started thread as stopped + return; + } + + // Load `active_threads` while we hold the mutex so + // it won't race with thread startup. + let active_threads = self.active_threads.load(Ordering::Relaxed); + + state.terminating = true; + + active_threads + }; + + // Wake up any threads which have trottled their startup since we need them + // to terminate now. + self.terminate_cond_var.notify_all(); + + // Mark threads that did not start up yet as terminated + for index in active_threads..self.num_threads() { + self.thread_infos[index].stopped.set(); + } + } } } @@ -569,14 +709,31 @@ struct ThreadInfo { /// the "stealer" half of the worker's deque stealer: Stealer, + + /// Holds the worker's half of the worker's deque, until this thread starts. + worker: Mutex>>, + + /// The name of this worker thread, if any. + name: Option, + + /// The wait time before bringing up the thread, if any. + wait_time: Option, } impl ThreadInfo { - fn new(stealer: Stealer) -> ThreadInfo { + fn new( + stealer: Stealer, + worker: Worker, + name: Option, + wait_time: Option, + ) -> ThreadInfo { ThreadInfo { primed: LockLatch::new(), stopped: LockLatch::new(), stealer, + worker: Mutex::new(Some(worker)), + name, + wait_time, } } } @@ -741,7 +898,7 @@ impl WorkerThread { debug_assert!(self.local_deque_is_empty()); // otherwise, try to steal - let num_threads = self.registry.thread_infos.len(); + let num_threads = self.registry.active_threads.load(Ordering::Acquire); if num_threads <= 1 { return None; } @@ -782,14 +939,59 @@ unsafe fn main_loop(worker: Worker, registry: Arc, index: usiz }; WorkerThread::set_current(worker_thread); - // let registry know we are ready to do work - registry.thread_infos[index].primed.set(); - // Worker threads should not panic. If they do, just abort, as the // internal state of the threadpool is corrupted. Note that if // **user code** panics, we should catch that and redirect. let abort_guard = unwind::AbortIfPanic; + if let Some(duration) = registry.thread_infos[index].wait_time { + registry + .terminate_cond_var + .wait_timeout(registry.thread_spawning_state.lock().unwrap(), duration); + } + + if registry.terminate_latch.probe() { + // The thread pool terminated while we where starting up or waiting. + + // Normal termination, do not abort. + mem::forget(abort_guard); + + return; + } + + registry.acquire_thread(); + + { + let mut state = registry.thread_spawning_state.lock().unwrap(); + + // If the registry terminated before we got the lock, we must pretend that this + // thread was never a part of the thread pool, so we release the thread and return. + if state.terminating { + mem::drop(state); + + // The thread pool terminated while we where waiting + registry.release_thread(); + + // Normal termination, do not abort. + mem::forget(abort_guard); + + return; + } + + state.spawning_thread = false; + + // Mark ourself as active so work can be stolen from us. + // We modify this while holding the mutex so it won't race with terminating + // the thread pool. + registry.active_threads.fetch_add(1, Ordering::Release); + + // Try to spawn a new thread if wanted + registry.try_spawn_thread(state).unwrap(); + } + + // Notify the sleep module of the new worker + registry.sleep.new_worker(); + // Inform a user callback that we started a thread. if let Some(ref handler) = registry.start_handler { let registry = registry.clone(); @@ -801,7 +1003,9 @@ unsafe fn main_loop(worker: Worker, registry: Arc, index: usiz } } - registry.acquire_thread(); + // let registry know we are ready to do work + registry.thread_infos[index].primed.set(); + worker_thread.wait_until(®istry.terminate_latch); // Should not be any work left in our queue. diff --git a/rayon-core/src/sleep/mod.rs b/rayon-core/src/sleep/mod.rs index f9c6f035b..9d3693518 100644 --- a/rayon-core/src/sleep/mod.rs +++ b/rayon-core/src/sleep/mod.rs @@ -10,7 +10,7 @@ use std::thread; use std::usize; struct SleepData { - /// The number of threads in the thread pool. + /// The number of threads in the thread pool that have started up so far. worker_count: usize, /// The number of threads in the thread pool which are running and @@ -45,18 +45,25 @@ const ROUNDS_UNTIL_SLEEPY: usize = 32; const ROUNDS_UNTIL_ASLEEP: usize = 64; impl Sleep { - pub(super) fn new(worker_count: usize) -> Sleep { + pub(super) fn new() -> Sleep { Sleep { state: AtomicUsize::new(AWAKE), data: Mutex::new(SleepData { - worker_count, - active_threads: worker_count, + worker_count: 0, + active_threads: 0, blocked_threads: 0, }), tickle: Condvar::new(), } } + pub(super) fn new_worker(&self) { + let mut data = self.data.lock().unwrap(); + + data.worker_count += 1; + data.active_threads += 1; + } + /// Mark a Rayon worker thread as blocked. This triggers the deadlock handler /// if no other worker thread is active #[inline] diff --git a/rayon-core/src/thread_pool/mod.rs b/rayon-core/src/thread_pool/mod.rs index 42ab28778..bb1d457cd 100644 --- a/rayon-core/src/thread_pool/mod.rs +++ b/rayon-core/src/thread_pool/mod.rs @@ -64,7 +64,7 @@ impl ThreadPool { builder: ThreadPoolBuilder, ) -> Result where - S: ThreadSpawn, + S: ThreadSpawn + Send + 'static, { let registry = Registry::new(builder)?; Ok(ThreadPool { registry })