diff --git a/onnxruntime/core/common/shared_ptr_thread_safe_wrapper.h b/onnxruntime/core/common/shared_ptr_thread_safe_wrapper.h deleted file mode 100644 index c1da450166949..0000000000000 --- a/onnxruntime/core/common/shared_ptr_thread_safe_wrapper.h +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include - -#include "core/platform/ort_mutex.h" - -namespace onnxruntime { - -// provides limited thread-safe access to a shared_ptr -// the underlying shared_ptr can, in a thread-safe manner: -// - be copied, holding a value which is initialized on demand, with GetInitialized() -// - be reset with Reset() -// `init_fn` is called to obtain the initialized value -template -class SharedPtrThreadSafeWrapper { - public: - using InitFn = std::function()>; - - explicit SharedPtrThreadSafeWrapper(InitFn init_fn) : init_fn_{init_fn} {} - - std::shared_ptr GetInitialized() { - std::scoped_lock lock{ptr_mutex_}; - if (!ptr_) { - ptr_ = init_fn_(); - } - return ptr_; - } - - void Reset() { - std::scoped_lock lock{ptr_mutex_}; - ptr_.reset(); - } - - private: - InitFn init_fn_; - - OrtMutex ptr_mutex_; - std::shared_ptr ptr_; -}; - -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index e3e15f019b7fa..9d48f8fe5f1a1 100755 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -3,7 +3,6 @@ #include "core/providers/shared_library/provider_api.h" #include "core/providers/cuda/cuda_execution_provider.h" -#include "core/common/shared_ptr_thread_safe_wrapper.h" #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/cuda_allocator.h" #include "core/providers/cuda/cuda_fence.h" @@ -2074,20 +2073,23 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { } // namespace cuda -static std::shared_ptr CreateCudaKernelRegistry() { - std::shared_ptr registry = KernelRegistry::Create(); - ORT_THROW_IF_ERROR(cuda::RegisterCudaKernels(*registry)); - return registry; -} +static std::shared_ptr& CudaKernelRegistry() { + // static local variable ensures thread-safe initialization + static std::shared_ptr cuda_kernel_registry = []() { + std::shared_ptr registry = KernelRegistry::Create(); + ORT_THROW_IF_ERROR(cuda::RegisterCudaKernels(*registry)); + return registry; + }(); -static SharedPtrThreadSafeWrapper s_kernel_registry{&CreateCudaKernelRegistry}; + return cuda_kernel_registry; +} void Shutdown_DeleteRegistry() { - s_kernel_registry.Reset(); + CudaKernelRegistry().reset(); } std::shared_ptr CUDAExecutionProvider::GetKernelRegistry() const { - return s_kernel_registry.GetInitialized(); + return CudaKernelRegistry(); } static bool RNNNeedFallbackToCPU(const onnxruntime::Node& node, diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 5a843d30fe85b..9f9ebdb3a868a 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -3,7 +3,6 @@ #include "core/providers/shared_library/provider_api.h" #include "core/providers/rocm/rocm_execution_provider.h" -#include "core/common/shared_ptr_thread_safe_wrapper.h" #include "core/providers/rocm/rocm_common.h" #include "core/providers/rocm/rocm_allocator.h" #include "core/providers/rocm/rocm_fence.h" @@ -2067,20 +2066,23 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { } // namespace rocm -static std::shared_ptr CreateRocmKernelRegistry() { - std::shared_ptr registry = KernelRegistry::Create(); - ORT_THROW_IF_ERROR(rocm::RegisterRocmKernels(*registry)); - return registry; -} +static std::shared_ptr& RocmKernelRegistry() { + // static local variable ensures thread-safe initialization + static std::shared_ptr rocm_kernel_registry = []() { + std::shared_ptr registry = KernelRegistry::Create(); + ORT_THROW_IF_ERROR(rocm::RegisterRocmKernels(*registry)); + return registry; + }(); -static SharedPtrThreadSafeWrapper s_kernel_registry{&CreateRocmKernelRegistry}; + return rocm_kernel_registry; +} void Shutdown_DeleteRegistry() { - s_kernel_registry.Reset(); + RocmKernelRegistry().reset(); } std::shared_ptr ROCMExecutionProvider::GetKernelRegistry() const { - return s_kernel_registry.GetInitialized(); + return RocmKernelRegistry(); } static bool CastNeedFallbackToCPU(const onnxruntime::Node& node) { diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 2f0b0969e9acf..10500e98b7323 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -7,7 +7,6 @@ #define ORT_API_MANUAL_INIT #include "core/session/onnxruntime_cxx_api.h" #include "core/common/safeint.h" -#include "core/common/shared_ptr_thread_safe_wrapper.h" #include "tensorrt_execution_provider.h" #include "core/providers/cuda/shared_inc/cuda_call.h" #include "core/providers/cuda/math/unary_elementwise_ops_impl.h" @@ -377,20 +376,23 @@ static Status RegisterTensorrtKernels(KernelRegistry& kernel_registry) { return Status::OK(); } -static std::shared_ptr CreateTensorrtKernelRegistry() { - std::shared_ptr registry = KernelRegistry::Create(); - ORT_THROW_IF_ERROR(RegisterTensorrtKernels(*registry)); - return registry; -} +static std::shared_ptr& TensorrtKernelRegistry() { + // static local variable ensures thread-safe initialization + static std::shared_ptr tensorrt_kernel_registry = []() { + std::shared_ptr registry = KernelRegistry::Create(); + ORT_THROW_IF_ERROR(RegisterTensorrtKernels(*registry)); + return registry; + }(); -static SharedPtrThreadSafeWrapper s_kernel_registry{&CreateTensorrtKernelRegistry}; + return tensorrt_kernel_registry; +} void Shutdown_DeleteRegistry() { - s_kernel_registry.Reset(); + TensorrtKernelRegistry().reset(); } std::shared_ptr TensorrtExecutionProvider::GetKernelRegistry() const { - return s_kernel_registry.GetInitialized(); + return TensorrtKernelRegistry(); } // Per TensorRT documentation, logger needs to be a singleton.