Skip to content

Commit

Permalink
Use function-local static variable instead.
Browse files Browse the repository at this point in the history
  • Loading branch information
edgchen1 committed Jan 7, 2022
1 parent c73f198 commit 0a68c6a
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 73 deletions.
46 changes: 0 additions & 46 deletions onnxruntime/core/common/shared_ptr_thread_safe_wrapper.h

This file was deleted.

20 changes: 11 additions & 9 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

#include "core/providers/shared_library/provider_api.h"
#include "core/providers/cuda/cuda_execution_provider.h"
#include "core/common/shared_ptr_thread_safe_wrapper.h"
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/cuda_allocator.h"
#include "core/providers/cuda/cuda_fence.h"
Expand Down Expand Up @@ -2074,20 +2073,23 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {

} // namespace cuda

static std::shared_ptr<KernelRegistry> CreateCudaKernelRegistry() {
std::shared_ptr<KernelRegistry> registry = KernelRegistry::Create();
ORT_THROW_IF_ERROR(cuda::RegisterCudaKernels(*registry));
return registry;
}
static std::shared_ptr<KernelRegistry>& CudaKernelRegistry() {
// static local variable ensures thread-safe initialization
static std::shared_ptr<KernelRegistry> cuda_kernel_registry = []() {
std::shared_ptr<KernelRegistry> registry = KernelRegistry::Create();
ORT_THROW_IF_ERROR(cuda::RegisterCudaKernels(*registry));
return registry;
}();

static SharedPtrThreadSafeWrapper<KernelRegistry> s_kernel_registry{&CreateCudaKernelRegistry};
return cuda_kernel_registry;
}

void Shutdown_DeleteRegistry() {
s_kernel_registry.Reset();
CudaKernelRegistry().reset();
}

std::shared_ptr<KernelRegistry> CUDAExecutionProvider::GetKernelRegistry() const {
return s_kernel_registry.GetInitialized();
return CudaKernelRegistry();
}

static bool RNNNeedFallbackToCPU(const onnxruntime::Node& node,
Expand Down
20 changes: 11 additions & 9 deletions onnxruntime/core/providers/rocm/rocm_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

#include "core/providers/shared_library/provider_api.h"
#include "core/providers/rocm/rocm_execution_provider.h"
#include "core/common/shared_ptr_thread_safe_wrapper.h"
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/rocm_allocator.h"
#include "core/providers/rocm/rocm_fence.h"
Expand Down Expand Up @@ -2067,20 +2066,23 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {

} // namespace rocm

static std::shared_ptr<KernelRegistry> CreateRocmKernelRegistry() {
std::shared_ptr<KernelRegistry> registry = KernelRegistry::Create();
ORT_THROW_IF_ERROR(rocm::RegisterRocmKernels(*registry));
return registry;
}
static std::shared_ptr<KernelRegistry>& RocmKernelRegistry() {
// static local variable ensures thread-safe initialization
static std::shared_ptr<KernelRegistry> rocm_kernel_registry = []() {
std::shared_ptr<KernelRegistry> registry = KernelRegistry::Create();
ORT_THROW_IF_ERROR(rocm::RegisterRocmKernels(*registry));
return registry;
}();

static SharedPtrThreadSafeWrapper<KernelRegistry> s_kernel_registry{&CreateRocmKernelRegistry};
return rocm_kernel_registry;
}

void Shutdown_DeleteRegistry() {
s_kernel_registry.Reset();
RocmKernelRegistry().reset();
}

std::shared_ptr<KernelRegistry> ROCMExecutionProvider::GetKernelRegistry() const {
return s_kernel_registry.GetInitialized();
return RocmKernelRegistry();
}

static bool CastNeedFallbackToCPU(const onnxruntime::Node& node) {
Expand Down
20 changes: 11 additions & 9 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#define ORT_API_MANUAL_INIT
#include "core/session/onnxruntime_cxx_api.h"
#include "core/common/safeint.h"
#include "core/common/shared_ptr_thread_safe_wrapper.h"
#include "tensorrt_execution_provider.h"
#include "core/providers/cuda/shared_inc/cuda_call.h"
#include "core/providers/cuda/math/unary_elementwise_ops_impl.h"
Expand Down Expand Up @@ -377,20 +376,23 @@ static Status RegisterTensorrtKernels(KernelRegistry& kernel_registry) {
return Status::OK();
}

static std::shared_ptr<KernelRegistry> CreateTensorrtKernelRegistry() {
std::shared_ptr<KernelRegistry> registry = KernelRegistry::Create();
ORT_THROW_IF_ERROR(RegisterTensorrtKernels(*registry));
return registry;
}
static std::shared_ptr<KernelRegistry>& TensorrtKernelRegistry() {
// static local variable ensures thread-safe initialization
static std::shared_ptr<KernelRegistry> tensorrt_kernel_registry = []() {
std::shared_ptr<KernelRegistry> registry = KernelRegistry::Create();
ORT_THROW_IF_ERROR(RegisterTensorrtKernels(*registry));
return registry;
}();

static SharedPtrThreadSafeWrapper<KernelRegistry> s_kernel_registry{&CreateTensorrtKernelRegistry};
return tensorrt_kernel_registry;
}

void Shutdown_DeleteRegistry() {
s_kernel_registry.Reset();
TensorrtKernelRegistry().reset();
}

std::shared_ptr<KernelRegistry> TensorrtExecutionProvider::GetKernelRegistry() const {
return s_kernel_registry.GetInitialized();
return TensorrtKernelRegistry();
}

// Per TensorRT documentation, logger needs to be a singleton.
Expand Down

0 comments on commit 0a68c6a

Please sign in to comment.