Fix race condition in CUDA, ROCm, and TensorRT EP GetKernelRegistry() implementations. #10200

Merged: 7 commits, Mar 2, 2022
23 changes: 12 additions & 11 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -2073,22 +2073,23 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {

 } // namespace cuda

-static std::shared_ptr<onnxruntime::KernelRegistry> s_kernel_registry;
+static std::shared_ptr<KernelRegistry>& CudaKernelRegistry() {
+  // static local variable ensures thread-safe initialization
+  static std::shared_ptr<KernelRegistry> cuda_kernel_registry = []() {
+    std::shared_ptr<KernelRegistry> registry = KernelRegistry::Create();
+    ORT_THROW_IF_ERROR(cuda::RegisterCudaKernels(*registry));
+    return registry;
+  }();
+
+  return cuda_kernel_registry;
+}

 void Shutdown_DeleteRegistry() {
-  s_kernel_registry.reset();
+  CudaKernelRegistry().reset();
 }

 std::shared_ptr<KernelRegistry> CUDAExecutionProvider::GetKernelRegistry() const {
-  if (!s_kernel_registry) {
-    s_kernel_registry = KernelRegistry::Create();
-    auto status = cuda::RegisterCudaKernels(*s_kernel_registry);
-    if (!status.IsOK())
-      s_kernel_registry.reset();
-    ORT_THROW_IF_ERROR(status);
-  }
-
-  return s_kernel_registry;
+  return CudaKernelRegistry();
 }

 static bool RNNNeedFallbackToCPU(const onnxruntime::Node& node,
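
For readers unfamiliar with the pattern that the new CudaKernelRegistry() relies on, here is a minimal, self-contained sketch of a function-local static initialized by a lambda. Since C++11, the initializer of such a variable runs exactly once, and threads that reach the declaration concurrently wait until it has finished. ToyRegistry, SharedRegistry, g_init_count, and InitCountAfterConcurrentCalls are hypothetical names used only for illustration; none of them are part of ONNX Runtime.

```cpp
#include <atomic>
#include <cassert>
#include <memory>
#include <string>
#include <thread>
#include <vector>

// Hypothetical stand-in for a kernel registry; NOT the ONNX Runtime class.
struct ToyRegistry {
  std::vector<std::string> kernels;
};

// Counts how many times the initializer lambda below actually runs.
static std::atomic<int> g_init_count{0};

// Function-local static: the lambda is invoked once, on first use.
// Concurrent first callers block until initialization completes.
static std::shared_ptr<ToyRegistry>& SharedRegistry() {
  static std::shared_ptr<ToyRegistry> registry = []() {
    g_init_count.fetch_add(1);
    auto r = std::make_shared<ToyRegistry>();
    r->kernels = {"Add", "MatMul"};  // illustrative entries only
    return r;
  }();
  return registry;
}

// Calls SharedRegistry() from num_threads threads and returns the number
// of times the initializer ran in total.
int InitCountAfterConcurrentCalls(int num_threads) {
  std::vector<ToyRegistry*> seen(num_threads, nullptr);
  std::vector<std::thread> threads;
  for (int i = 0; i < num_threads; ++i) {
    // Each thread writes only to its own slot, so no extra locking is needed.
    threads.emplace_back([&seen, i] { seen[i] = SharedRegistry().get(); });
  }
  for (auto& t : threads) t.join();
  // Every thread must have observed the same registry object.
  for (ToyRegistry* p : seen) assert(p == seen[0]);
  return g_init_count.load();
}
```

Under these assumptions, InitCountAfterConcurrentCalls(8) is expected to return 1, because the initializer runs only once no matter how many threads race to the first call. One property worth keeping in mind when reading the diff: calling reset() through the returned reference, as Shutdown_DeleteRegistry() does, leaves the static holding a null pointer, and the initializer does not run again on later calls.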