diff --git a/cmake/onnxruntime_common.cmake b/cmake/onnxruntime_common.cmake
index 0799ab9a6c79e..133ea4b60bf16 100644
--- a/cmake/onnxruntime_common.cmake
+++ b/cmake/onnxruntime_common.cmake
@@ -53,11 +53,12 @@ target_include_directories(onnxruntime_common PRIVATE ${CMAKE_CURRENT_BINARY_DIR
 if(onnxruntime_USE_NSYNC)
   target_compile_definitions(onnxruntime_common PUBLIC USE_NSYNC)
 endif()
-if(onnxruntime_USE_EIGEN_THREADPOOL)
-  target_include_directories(onnxruntime_common PRIVATE ${eigen_INCLUDE_DIRS})
-  target_compile_definitions(onnxruntime_common PUBLIC USE_EIGEN_THREADPOOL)
-  add_dependencies(onnxruntime_common ${onnxruntime_EXTERNAL_DEPENDENCIES})
+
+target_include_directories(onnxruntime_common PUBLIC ${eigen_INCLUDE_DIRS})
+if(NOT onnxruntime_USE_OPENMP)
+  target_compile_definitions(onnxruntime_common PUBLIC EIGEN_USE_THREADS)
 endif()
+add_dependencies(onnxruntime_common ${onnxruntime_EXTERNAL_DEPENDENCIES})
 
 install(DIRECTORY ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/common DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core)
 set_target_properties(onnxruntime_common PROPERTIES LINKER_LANGUAGE CXX)
diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index 8fdbca776745f..5d12ac067bd83 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -190,5 +190,5 @@ else()
 endif()
 
 add_library(onnxruntime_mlas STATIC ${mlas_common_srcs} ${mlas_platform_srcs})
-target_include_directories(onnxruntime_mlas PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${ONNXRUNTIME_ROOT}/core/mlas/lib)
+target_include_directories(onnxruntime_mlas PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${ONNXRUNTIME_ROOT}/core/mlas/lib ${eigen_INCLUDE_DIRS})
 set_target_properties(onnxruntime_mlas PROPERTIES FOLDER "ONNXRuntime")
diff --git a/include/onnxruntime/core/platform/threadpool.h b/include/onnxruntime/core/platform/threadpool.h
index 66952591ce470..3337583612065 100644
--- a/include/onnxruntime/core/platform/threadpool.h
+++ b/include/onnxruntime/core/platform/threadpool.h
@@ -7,12 +7,27 @@
 #include
 #include
 
+#if defined(__GNUC__)
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wunused-parameter"
+#else
+#pragma warning(push)
+#pragma warning(disable : 4267)
+#endif
+#include <unsupported/Eigen/CXX11/ThreadPool>
+#if defined(__GNUC__)
+#pragma GCC diagnostic pop
+#else
+#pragma warning(pop)
+#endif
+
 namespace onnxruntime {
 namespace concurrency {
 
 /**
  * Generic class for instantiating thread pools.
+ * Don't put any object of this type into a global variable in a Win32 DLL.
  */
 class ThreadPool {
  public:
@@ -43,14 +58,10 @@ class ThreadPool {
 
   int CurrentThreadId() const;
 
-  /*
-  Ensure that the pool has terminated and cleaned up all threads cleanly.
-  */
-  ~ThreadPool();
+  Eigen::ThreadPool& GetHandler() { return impl_; }
 
  private:
-  class Impl;
-  std::unique_ptr<Impl> impl_;
+  Eigen::ThreadPool impl_;
 };
 
 }  // namespace concurrency
diff --git a/onnxruntime/core/common/task_thread_pool.h b/onnxruntime/core/common/task_thread_pool.h
deleted file mode 100644
index 1cc0d64ecfd6b..0000000000000
--- a/onnxruntime/core/common/task_thread_pool.h
+++ /dev/null
@@ -1,213 +0,0 @@
-/**
- * Copyright (c) 2016-present, Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/*
-Changed to use std::packaged_task instead of std::function so exceptions can be propagated.
-
-This also allows the task threadpool to be shared across multiple operators as the caller
-can keep a container of the packaged_task futures to check when they have completed. Calling
-WaitWorkComplete in that use case is invalid as there may be other concurrent usage of the
-threadpool.
-
-Example of that usage:
-
-  std::vector<std::future<void>> task_results{};
-
-  for (...) {
-    std::packaged_task<void()> task{std::bind(lambda, i)};
-    task_results.push_back(task.get_future());
-    task_thread_pool.RunTask(std::move(task));
-  }
-
-  try {
-    // wait for all and propagate any exceptions
-    for (auto& future : task_results)
-      future.get();
-  } catch (const std::exception& ex) {
-    ...
-    throw;
-  }
-
-*/
-
-#pragma once
-
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
-#include "core/common/common.h"
-#include "core/common/logging/logging.h"
-#include "core/platform/ort_mutex.h"
-
-namespace onnxruntime {
-
-class TaskThreadPool {
- private:
-  struct task_element_t {
-    bool run_with_id;
-    std::packaged_task<void()> no_id;
-    std::packaged_task<void(std::size_t)> with_id;
-
-    task_element_t(task_element_t&& other) noexcept {
-      run_with_id = other.run_with_id;
-      no_id = std::move(other.no_id);
-      with_id = std::move(other.with_id);
-    }
-
-    explicit task_element_t(std::packaged_task<void()>&& f)
-        : run_with_id(false), no_id(std::move(f)) {}
-
-    explicit task_element_t(std::packaged_task<void(std::size_t)>&& f)
-        : run_with_id(true), with_id(std::move(f)) {}
-  };
-
-  std::queue<task_element_t> tasks_;
-  std::vector<std::thread> threads_;
-  OrtMutex mutex_;
-  OrtCondVar condition_;
-  OrtCondVar completed_;
-  bool running_;
-  bool complete_;
-  std::size_t available_;
-  std::size_t total_;
-
- public:
-  /// @brief Constructor.
-  explicit TaskThreadPool(std::size_t pool_size)
-      : threads_(pool_size), running_(true), complete_(true), available_(pool_size), total_(pool_size) {
-    for (std::size_t i = 0; i < pool_size; ++i) {
-      threads_[i] = std::thread(std::bind(&TaskThreadPool::MainLoop, this, i));
-    }
-  }
-
-  /// @brief Destructor.
-  ~TaskThreadPool() {
-    // Set running flag to false then notify all threads.
-    {
-      std::unique_lock<OrtMutex> lock(mutex_);
-      running_ = false;
-      condition_.notify_all();
-    }
-
-    try {
-      for (auto& t : threads_) {
-        t.join();
-      }
-    }
-    // Suppress all exceptions.
-    catch (const std::exception& ex) {
-      LOGS_DEFAULT(ERROR) << "Exception joining threads in TaskThreadPool: " << ex.what();
-    }
-  }
-
-  int NumThreads() const {
-    return (int)threads_.size();
-  }
-
-  // This thread pool does not support ids
-  int CurrentThreadId() const {
-    return -1;
-  }
-
-  void RunTask(std::packaged_task<void()>&& task) {
-    std::unique_lock<OrtMutex> lock(mutex_);
-
-    // Set task and signal condition variable so that a worker thread will
-    // wake up and use the task.
-    tasks_.push(task_element_t(std::move(task)));
-    complete_ = false;
-    condition_.notify_one();
-  }
-
-  void RunTaskWithID(std::packaged_task<void(std::size_t)>&& task) {
-    std::unique_lock<OrtMutex> lock(mutex_);
-
-    // Set task and signal condition variable so that a worker thread will
-    // wake up and use the task.
-    tasks_.push(task_element_t(std::move(task)));
-    complete_ = false;
-    condition_.notify_one();
-  }
-
-  /// @brief Wait for queue to be empty
-  void WaitWorkComplete() {
-    std::unique_lock<OrtMutex> lock(mutex_);
-    while (!complete_)
-      completed_.wait(lock);
-  }
-
- private:
-  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TaskThreadPool);
-
-  /// @brief Entry point for pool threads.
-  void MainLoop(std::size_t index) {
-    while (running_) {
-      // Wait on condition variable while the task is empty and
-      // the pool is still running.
-      std::unique_lock<OrtMutex> lock(mutex_);
-      while (tasks_.empty() && running_) {
-        condition_.wait(lock);
-      }
-
-      // If pool is no longer running, break out of loop.
-      if (!running_) break;
-
-      // Copy task locally and remove from the queue. This is
-      // done within its own scope so that the task object is
-      // destructed immediately after running the task. This is
-      // useful in the event that the function contains
-      // shared_ptr arguments bound via bind.
-      {
-        auto task = std::move(tasks_.front());
-        tasks_.pop();
-        // Decrement count, indicating thread is no longer available.
-        --available_;
-
-        lock.unlock();
-
-        // Run the task.
-        try {
-          if (task.run_with_id) {
-            task.with_id(index);
-          } else {
-            task.no_id();
-          }
-        } catch (const std::exception& /*ex*/) {
-          // LOGS_DEFAULT(ERROR) << "Exception running TaskThreadPool task: " << ex.what();
-          throw;
-        }
-
-        // Update status of empty, maybe
-        // Need to recover the lock first
-        lock.lock();
-
-        // Increment count, indicating thread is available.
-        ++available_;
-        if (tasks_.empty() && available_ == total_) {
-          complete_ = true;
-          completed_.notify_one();
-        }
-      }
-    }  // while running_
-  }
-};
-
-}  // namespace onnxruntime
diff --git a/onnxruntime/core/common/threadpool.cc b/onnxruntime/core/common/threadpool.cc
index 07305a41d0645..6cdcb3add7cf0 100644
--- a/onnxruntime/core/common/threadpool.cc
+++ b/onnxruntime/core/common/threadpool.cc
@@ -6,174 +6,31 @@
 #include
 
-#ifdef USE_EIGEN_THREADPOOL
-#if defined(_MSC_VER)
-#pragma warning(disable : 4267)
-#endif
-
 #if defined(__GNUC__)
 #pragma GCC diagnostic push
 #pragma GCC diagnostic ignored "-Wunused-parameter"
+#else
+#pragma warning(push)
+#pragma warning(disable : 4267)
 #endif
-#include <unsupported/Eigen/CXX11/ThreadPool>
+#include <unsupported/Eigen/CXX11/ThreadPool>
 #if defined(__GNUC__)
 #pragma GCC diagnostic pop
-#endif
 #else
-#include "task_thread_pool.h"
+#pragma warning(pop)
 #endif
+
+using Eigen::Barrier;
+
 namespace onnxruntime {
 namespace concurrency {
-
-// TODO: This is temporarily taken from Eigen until we upgrade its version.
-// Barrier is an object that allows one or more threads to wait until
-// Notify has been called a specified number of times.
-class Barrier {
- public:
-  Barrier(unsigned int count) : state_(count << 1), notified_(false) {
-    assert(((count << 1) >> 1) == count);
-  }
-  ~Barrier() {
-    assert((state_ >> 1) == 0);
-  }
-
-  void Notify() {
-    unsigned int v = state_.fetch_sub(2, std::memory_order_acq_rel) - 2;
-    if (v != 1) {
-      assert(((v + 2) & ~1) != 0);
-      return;  // either count has not dropped to 0, or waiter is not waiting
-    }
-    std::unique_lock<std::mutex> l(mu_);
-    assert(!notified_);
-    notified_ = true;
-    cv_.notify_all();
-  }
-
-  void Wait() {
-    unsigned int v = state_.fetch_or(1, std::memory_order_acq_rel);
-    if ((v >> 1) == 0) return;
-    std::unique_lock<std::mutex> l(mu_);
-    while (!notified_) {
-      cv_.wait(l);
-    }
-  }
-
- private:
-  std::mutex mu_;
-  std::condition_variable cv_;
-  std::atomic<unsigned int> state_;  // low bit is waiter flag
-  bool notified_;
-};
-
-#ifdef USE_EIGEN_THREADPOOL
-class ThreadPool::Impl : public Eigen::ThreadPool {
- public:
-  Impl(const std::string& name, int num_threads)
-      : Eigen::ThreadPool(num_threads) {
-    ORT_UNUSED_PARAMETER(name);
-  }
-
-  void ParallelFor(int32_t total, std::function<void(int32_t)> fn) {
-    // TODO: Eigen supports a more efficient ThreadPoolDevice mechanism
-    // We will simply rely on the work queue and stealing in the short term.
-    Barrier barrier(static_cast<unsigned int>(total - 1));
-    std::function<void(int)> handle_iteration = [&barrier, &fn](int iteration) {
-      fn(iteration);
-      barrier.Notify();
-    };
-
-    for (int32_t id = 1; id < total; ++id) {
-      Schedule([=, &handle_iteration]() { handle_iteration(id); });
-    }
-
-    fn(0);
-    barrier.Wait();
-  }
-
-  void ParallelForRange(int64_t first, int64_t last, std::function<void(int64_t, int64_t)> fn) {
-    // TODO: Eigen supports a more efficient ThreadPoolDevice mechanism
-    // We will simply rely on the work queue and stealing in the short term.
-    Barrier barrier(static_cast<unsigned int>(last - first));
-    std::function<void(int64_t, int64_t)> handle_range = [&barrier, &fn](int64_t first, int64_t last) {
-      fn(first, last);
-      barrier.Notify();
-    };
-
-    for (int64_t id = first + 1; id <= last; ++id) {
-      Schedule([=, &handle_range]() { handle_range(id, id + 1); });
-    }
-
-    fn(first, first + 1);
-    barrier.Wait();
-  }
-};
-#else
-class ThreadPool::Impl : public TaskThreadPool {
- public:
-  Impl(const std::string& name, int num_threads)
-      : TaskThreadPool(num_threads) {
-    ORT_UNUSED_PARAMETER(name);
-  }
-
-  void Schedule(std::function<void()> fn) {
-    std::packaged_task<void()> task(fn);
-    RunTask(std::move(task));
-  }
-
-  void ParallelFor(int32_t total, std::function<void(int32_t)> fn) {
-#ifdef USE_OPENMP
-#pragma omp parallel for
-    for (int32_t id = 0; id < total; ++id) {
-      fn(id);
-    }
-#else
-    Barrier barrier(static_cast<unsigned int>(total - 1));
-    std::function<void(int)> handle_iteration = [&barrier, &fn](int iteration) {
-      fn(iteration);
-      barrier.Notify();
-    };
-    for (int32_t id = 1; id < total; ++id) {
-      std::packaged_task<void()> task(std::bind(handle_iteration, id));
-      RunTask(std::move(task));
-    }
-    fn(0);
-    barrier.Wait();
-#endif
-  }
-
-  void ParallelForRange(int64_t first, int64_t last, std::function<void(int64_t, int64_t)> fn) {
-#ifdef USE_OPENMP
-#pragma omp parallel for
-    for (int64_t id = first; id < last; ++id) {
-      fn(id, id + 1);
-    }
-#else
-    Barrier barrier(static_cast<unsigned int>(last - first));
-    std::function<void(int64_t, int64_t)> handle_iteration = [&barrier, &fn](int64_t first, int64_t last) {
-      fn(first, last);
-      barrier.Notify();
-    };
-    for (int64_t id = first + 1; id < last; ++id) {
-      std::packaged_task<void()> task(std::bind(handle_iteration, id, id + 1));
-      RunTask(std::move(task));
-    }
-    fn(first, first + 1);
-    barrier.Wait();
-#endif
-  }
-};
-#endif
-
 //
 // ThreadPool
 //
-ThreadPool::ThreadPool(const std::string& name, int num_threads)
-    : impl_(std::make_unique<Impl>(name, num_threads)) {
-}
+ThreadPool::ThreadPool(const std::string&, int num_threads) : impl_(num_threads) {}
 
-void ThreadPool::Schedule(std::function<void()> fn) { impl_->Schedule(fn); }
+void ThreadPool::Schedule(std::function<void()> fn) { impl_.Schedule(fn); }
 
 void ThreadPool::ParallelFor(int32_t total, std::function<void(int32_t)> fn) {
   if (total <= 0) return;
@@ -183,7 +40,20 @@ void ThreadPool::ParallelFor(int32_t total, std::function<void(int32_t)> fn) {
     return;
   }
 
-  impl_->ParallelFor(total, fn);
+  // TODO: Eigen supports a more efficient ThreadPoolDevice mechanism
+  // We will simply rely on the work queue and stealing in the short term.
+  Barrier barrier(static_cast<unsigned int>(total - 1));
+  std::function<void(int)> handle_iteration = [&barrier, &fn](int iteration) {
+    fn(iteration);
+    barrier.Notify();
+  };
+
+  for (int32_t id = 1; id < total; ++id) {
+    Schedule([=, &handle_iteration]() { handle_iteration(id); });
+  }
+
+  fn(0);
+  barrier.Wait();
 }
 
 void ThreadPool::ParallelForRange(int64_t first, int64_t last, std::function<void(int64_t, int64_t)> fn) {
@@ -193,18 +63,28 @@ void ThreadPool::ParallelForRange(int64_t first, int64_t last, std::function<void(int64_t, int64_t)> fn) {
     return;
  }
 
-  impl_->ParallelForRange(first, last, fn);
+  // TODO: Eigen supports a more efficient ThreadPoolDevice mechanism
+  // We will simply rely on the work queue and stealing in the short term.
+  Barrier barrier(static_cast<unsigned int>(last - first));
+  std::function<void(int64_t, int64_t)> handle_range = [&barrier, &fn](int64_t first, int64_t last) {
+    fn(first, last);
+    barrier.Notify();
+  };
+
+  for (int64_t id = first + 1; id <= last; ++id) {
+    Schedule([=, &handle_range]() { handle_range(id, id + 1); });
+  }
+
+  fn(first, first + 1);
+  barrier.Wait();
 }
 
 // void ThreadPool::SetStealPartitions(const std::vector<std::pair<unsigned, unsigned>>& partitions) {
 //   impl_->SetStealPartitions(partitions);
 // }
 
-int ThreadPool::NumThreads() const { return impl_->NumThreads(); }
-
-int ThreadPool::CurrentThreadId() const { return impl_->CurrentThreadId(); }
-
-ThreadPool::~ThreadPool() {}
+int ThreadPool::NumThreads() const { return impl_.NumThreads(); }
+int ThreadPool::CurrentThreadId() const { return impl_.CurrentThreadId(); }
 
 }  // namespace concurrency
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cpu/tensor/onehot.cc b/onnxruntime/core/providers/cpu/tensor/onehot.cc
index 9d9b1cff7470c..c4f0c2479a069 100644
--- a/onnxruntime/core/providers/cpu/tensor/onehot.cc
+++ b/onnxruntime/core/providers/cpu/tensor/onehot.cc
@@ -18,8 +18,9 @@ limitations under the License.
 
 #include "core/util/eigen_common_wrapper.h"
 #include "core/platform/env.h"
+#ifndef EIGEN_USE_THREADS
 #define EIGEN_USE_THREADS
-
+#endif
 using namespace ::onnxruntime::common;
 using namespace std;
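
Note (not part of the patch): a minimal usage sketch of the Eigen-backed pool this diff introduces, assuming the public API declared in include/onnxruntime/core/platform/threadpool.h after the change. The pool name, sizes, and the summation lambda below are illustrative only.

// Illustrative sketch only -- not from the patch above.
#include <atomic>
#include <cstdint>
#include <vector>

#include "core/platform/threadpool.h"

int main() {
  // The name argument is ignored by the new Eigen-backed constructor.
  onnxruntime::concurrency::ThreadPool pool("demo", 4);

  std::vector<int> data(1000, 1);
  std::atomic<int64_t> sum{0};

  // Each iteration is scheduled on the pool; ParallelFor returns only after
  // every iteration has run (the Barrier in threadpool.cc provides the join).
  pool.ParallelFor(static_cast<int32_t>(data.size()),
                   [&](int32_t i) { sum += data[i]; });

  return sum == static_cast<int64_t>(data.size()) ? 0 : 1;
}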