From d52fc0019e3ba448a3ad8b66a560f287ba015e9a Mon Sep 17 00:00:00 2001
From: Hector Li
Date: Thu, 1 Aug 2019 12:07:40 -0700
Subject: [PATCH 1/3] Fix race condition issue in RNN/LSTM/GRU.

Description: filter_desc and rnn_desc can also be modified inside Compute(),
which may run on multiple threads; this causes a race condition.
Fix: add a lock scope around the code that modifies and uses these two
descriptors.
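A minimal sketch of the locking pattern this patch applies (illustrative
stand-in types only; the real code uses the CudnnFilterDescriptor/CudnnRNN
wrappers and OrtMutex from this file, while std::mutex is used here so the
sketch compiles on its own):

#include <mutex>

// Stand-ins for the cuDNN descriptor wrappers used by the kernel.
struct CudnnFilterDescriptor {};
struct CudnnRNN {};

// Mutable shared state is grouped together with the mutex that guards it.
struct RNNDescriptors {
  CudnnFilterDescriptor filter_desc;
  CudnnRNN rnn_desc;
  std::mutex mutex;  // OrtMutex in the actual patch
};

// ComputeInternal() is const but may run concurrently, so every access to
// the descriptors happens inside one lock scope.
void ComputeInternal(RNNDescriptors& descriptors) {
  std::lock_guard<std::mutex> lock(descriptors.mutex);
  // ... configure and use descriptors.filter_desc / descriptors.rnn_desc ...
}

Holding the lock for the whole scope matters here: the descriptors are both
written (reconfigured) and read (passed to cuDNN calls) in the same region,
so a narrower lock would still race.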
---
 .../core/providers/cuda/rnn/cudnn_rnn_base.cc | 125 +++++++++---------
 .../core/providers/cuda/rnn/cudnn_rnn_base.h  |  36 +++--
 onnxruntime/core/providers/cuda/rnn/gru.h     |   2 +-
 onnxruntime/core/providers/cuda/rnn/lstm.h    |   2 +-
 onnxruntime/core/providers/cuda/rnn/rnn.h     |   4 +-
 5 files changed, 93 insertions(+), 76 deletions(-)

diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc
index e45eb16dc5508..5bf396c5c9272 100644
--- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc
+++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc
@@ -52,15 +52,15 @@ Status CudnnRnnBase<T>::SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle,
   int bias_offset = 0;
   for (int layer = 0; layer < num_layers_ * num_directions_; ++layer) {
     for (size_t idx = 0; idx < W_lin_layer_id_.size(); ++idx) {
-      SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc_, w_data, W_lin_layer_id_[idx], W_data, w_offset, true);
+      SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, rnn_descriptors_.filter_desc, w_data, W_lin_layer_id_[idx], W_data, w_offset, true);
       if (B_data != nullptr) {
-        SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc_, w_data, W_lin_layer_id_[idx], B_data, bias_offset, false);
+        SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, rnn_descriptors_.filter_desc, w_data, W_lin_layer_id_[idx], B_data, bias_offset, false);
       }
     }
     for (size_t idx = 0; idx < R_lin_layer_id_.size(); ++idx) {
-      SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc_, w_data, R_lin_layer_id_[idx], R_data, r_offset, true);
+      SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, rnn_descriptors_.filter_desc, w_data, R_lin_layer_id_[idx], R_data, r_offset, true);
       if (B_data != nullptr) {
-        SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc_, w_data, R_lin_layer_id_[idx], B_data, bias_offset, false);
+        SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, rnn_descriptors_.filter_desc, w_data, R_lin_layer_id_[idx], B_data, bias_offset, false);
       }
     }
   }
@@ -86,8 +86,8 @@ Status CudnnRnnBase<T>::SetCudnnRnnDesc() {
   cudnn_dropout_desc_.GetCudnnDropoutStatesSize(CudnnHandle(), state_size_);
   state_buffer_ = GetScratchBuffer<void>(state_size_);
   cudnn_dropout_desc_.Set(CudnnHandle(), state_buffer_.get(), state_size_);
-  ORT_RETURN_IF_ERROR(rnn_desc_.Set(CudnnHandle(), hidden_size_, num_layers_, cudnn_dropout_desc_,
-                                    cudnn_direction, rnn_mode_, CudnnTensor::GetDataType<CudaT>()));
+  ORT_RETURN_IF_ERROR(rnn_descriptors_.rnn_desc.Set(CudnnHandle(), hidden_size_, num_layers_, cudnn_dropout_desc_,
+                                                    cudnn_direction, rnn_mode_, CudnnTensor::GetDataType<CudaT>()));

   return Status::OK();
 }
@@ -123,8 +123,8 @@ Status CudnnRnnBase<T>::ReorganizeWeights(const Tensor* W, const Tensor* R, const Tensor* B,
   const T* R_data = R->template Data<T>();
   const T* B_data = B == nullptr ? nullptr : B->template Data<T>();

-  ORT_RETURN_IF_ERROR(SetCudnnRnnWeightBias(CudnnHandle(), rnn_desc_, fake_x_desc, target_w_desc,
-                                            target_w_data.get(), W_data, R_data, B_data));
+  ORT_RETURN_IF_ERROR(SetCudnnRnnWeightBias(CudnnHandle(), rnn_descriptors_.rnn_desc, fake_x_desc, target_w_desc,
+                                            target_w_data.get(), W_data, R_data, B_data));

   return Status::OK();
 }
@@ -198,16 +198,6 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
   ORT_RETURN_IF_ERROR(y_h_desc.Set(dims_hxy, CudnnTensor::GetDataType<CudaT>()));
   ORT_RETURN_IF_ERROR(y_c_desc.Set(dims_hxy, CudnnTensor::GetDataType<CudaT>()));

-  // Prepare the weight data
-  IAllocatorUniquePtr<void> w_data;
-  CudnnFilterDescriptor w_desc;
-  if (!weight_cached_) {
-    const Tensor& W = *ctx->Input<Tensor>(RNN_Input_Index::W);
-    const Tensor& R = *ctx->Input<Tensor>(RNN_Input_Index::R);
-    const Tensor* B = ctx->Input<Tensor>(RNN_Input_Index::B);
-    ReorganizeWeights(&W, &R, B, w_data, w_desc);
-  }
-
   IAllocatorUniquePtr<T> x_reversed_data;
   const T* x_data = X->template Data<T>();
   if (reverse_) {
@@ -239,42 +229,31 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
   const int32_t* sequence_lens_data = (sequence_lens == nullptr) ? nullptr : sequence_lens->template Data<int32_t>();

-  // CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED works with CUDNN_RNN_PADDED_IO_ENABLED, so that shorter sequences are automatically zero-padded
-  CUDNN_RETURN_IF_ERROR(cudnnSetRNNPaddingMode(rnn_desc_, CUDNN_RNN_PADDED_IO_ENABLED));
-
-  size_t workspace_bytes;
-  CUDNN_RETURN_IF_ERROR(cudnnGetRNNWorkspaceSize(CudnnHandle(), rnn_desc_, gsl::narrow_cast<int>(seq_length), x_desc.data(), &workspace_bytes));
-  auto workspace_cuda = GetScratchBuffer<void>(workspace_bytes);
-
-  if (CUDNN_RNN_RELU == rnn_mode_ || CUDNN_RNN_TANH == rnn_mode_ || nullptr == sequence_lens_data) {
-    CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInference(CudnnHandle(),
-                                                   rnn_desc_,
-                                                   gsl::narrow_cast<int>(seq_length),
-                                                   x_desc.data(),
-                                                   x_data_input,
-                                                   hx_desc,
-                                                   hx_data,
-                                                   cx_desc,
-                                                   cx_data,
-                                                   weight_cached_ ? w_desc_cache_ : w_desc,
-                                                   weight_cached_ ? w_data_cache_.get() : w_data.get(),
-                                                   y_desc.data(),
-                                                   y_data,
-                                                   y_h_desc,
-                                                   y_h_data,
-                                                   y_c_desc,
-                                                   y_c_data,
-                                                   workspace_cuda.get(),
-                                                   workspace_bytes));
-  } else {
-    CudnnDataTensor x_desc;
-    x_desc.Set(CudnnTensor::GetDataType<CudaT>(), seq_length, batch_size, input_size, sequence_lens_data);
-    CudnnDataTensor y_desc;
-    y_desc.Set(CudnnTensor::GetDataType<CudaT>(), seq_length, batch_size, hidden_size_ * num_directions_, sequence_lens_data);
-
-    CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInferenceEx(CudnnHandle(),
-                                                     rnn_desc_,
-                                                     x_desc,
+  {
+    std::lock_guard<OrtMutex> lock(rnn_descriptors_.mutex);
+
+    // Prepare the weight data
+    IAllocatorUniquePtr<void> w_data;
+    CudnnFilterDescriptor w_desc;
+    if (!weight_cached_) {
+      const Tensor& W = *ctx->Input<Tensor>(RNN_Input_Index::W);
+      const Tensor& R = *ctx->Input<Tensor>(RNN_Input_Index::R);
+      const Tensor* B = ctx->Input<Tensor>(RNN_Input_Index::B);
+      ReorganizeWeights(&W, &R, B, w_data, w_desc);
+    }
+
+    // CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED works with CUDNN_RNN_PADDED_IO_ENABLED, so that shorter sequences are automatically zero-padded
+    CUDNN_RETURN_IF_ERROR(cudnnSetRNNPaddingMode(rnn_descriptors_.rnn_desc, CUDNN_RNN_PADDED_IO_ENABLED));
+
+    size_t workspace_bytes;
+    CUDNN_RETURN_IF_ERROR(cudnnGetRNNWorkspaceSize(CudnnHandle(), rnn_descriptors_.rnn_desc, gsl::narrow_cast<int>(seq_length), x_desc.data(), &workspace_bytes));
+    auto workspace_cuda = GetScratchBuffer<void>(workspace_bytes);
+
+    if (CUDNN_RNN_RELU == rnn_mode_ || CUDNN_RNN_TANH == rnn_mode_ || nullptr == sequence_lens_data) {
+      CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInference(CudnnHandle(),
+                                                     rnn_descriptors_.rnn_desc,
+                                                     gsl::narrow_cast<int>(seq_length),
+                                                     x_desc.data(),
                                                      x_data_input,
                                                      hx_desc,
                                                      hx_data,
@@ -282,22 +261,48 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
                                                      cx_data,
                                                      weight_cached_ ? w_desc_cache_ : w_desc,
                                                      weight_cached_ ? w_data_cache_.get() : w_data.get(),
-                                                     y_desc,
+                                                     y_desc.data(),
                                                      y_data,
                                                      y_h_desc,
                                                      y_h_data,
                                                      y_c_desc,
                                                      y_c_data,
-                                                     nullptr, nullptr, nullptr, nullptr,
-                                                     nullptr, nullptr, nullptr, nullptr,
                                                      workspace_cuda.get(),
                                                      workspace_bytes));
-    // Early terminate for this case since Y data is not required, and Y_h is obtained correctly; the following code to retrieve Y_h from the Y data is not needed.
-    if (nullptr == Y) {
-      return Status::OK();
+    } else {
+      CudnnDataTensor x_desc;
+      x_desc.Set(CudnnTensor::GetDataType<CudaT>(), seq_length, batch_size, input_size, sequence_lens_data);
+      CudnnDataTensor y_desc;
+      y_desc.Set(CudnnTensor::GetDataType<CudaT>(), seq_length, batch_size, hidden_size_ * num_directions_, sequence_lens_data);
+
+      CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInferenceEx(CudnnHandle(),
+                                                       rnn_descriptors_.rnn_desc,
+                                                       x_desc,
+                                                       x_data_input,
+                                                       hx_desc,
+                                                       hx_data,
+                                                       cx_desc,
+                                                       cx_data,
+                                                       weight_cached_ ? w_desc_cache_ : w_desc,
+                                                       weight_cached_ ? w_data_cache_.get() : w_data.get(),
+                                                       y_desc,
+                                                       y_data,
+                                                       y_h_desc,
+                                                       y_h_data,
+                                                       y_c_desc,
+                                                       y_c_data,
+                                                       nullptr, nullptr, nullptr, nullptr,
+                                                       nullptr, nullptr, nullptr, nullptr,
+                                                       workspace_cuda.get(),
+                                                       workspace_bytes));
+      // Early terminate for this case since Y data is not required, and Y_h is obtained correctly; the following code to retrieve Y_h from the Y data is not needed.
+      if (nullptr == Y) {
+        return Status::OK();
+      }
+    }
   }

   IAllocatorUniquePtr<T> y_reorganized_data;
   if (reverse_ || num_directions_ == 2) {
     //reverse output
diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h
index 0afd35435cc7c..2c41aee3cb02a 100644
--- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h
+++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h
@@ -23,24 +23,24 @@ enum RNN_Input_Index {

 class CudnnRNN {
  public:
-  CudnnRNN() : rnn_desc_(nullptr) {
+  CudnnRNN() : cudnn_rnn_desc_(nullptr) {
   }

   ~CudnnRNN() {
-    if (rnn_desc_ != nullptr) {
-      cudnnDestroyRNNDescriptor(rnn_desc_);
-      rnn_desc_ = nullptr;
+    if (cudnn_rnn_desc_ != nullptr) {
+      cudnnDestroyRNNDescriptor(cudnn_rnn_desc_);
+      cudnn_rnn_desc_ = nullptr;
     }
   }

   Status Set(const cudnnHandle_t& cudnnHandle, int64_t hidden_size, int num_layers,
              cudnnDropoutDescriptor_t cudnn_dropout_desc, cudnnDirectionMode_t cudnn_direction_model,
              cudnnRNNMode_t rnn_mode, cudnnDataType_t dataType) {
-    if (!rnn_desc_)
-      CUDNN_RETURN_IF_ERROR(cudnnCreateRNNDescriptor(&rnn_desc_));
+    if (!cudnn_rnn_desc_)
+      CUDNN_RETURN_IF_ERROR(cudnnCreateRNNDescriptor(&cudnn_rnn_desc_));

     CUDNN_RETURN_IF_ERROR(cudnnSetRNNDescriptor(cudnnHandle,
-                                                rnn_desc_,
+                                                cudnn_rnn_desc_,
                                                 gsl::narrow_cast<int>(hidden_size),
                                                 num_layers,
                                                 cudnn_dropout_desc,
@@ -54,11 +54,18 @@ class CudnnRNN {
   }

   operator cudnnRNNDescriptor_t() const {
-    return rnn_desc_;
+    return cudnn_rnn_desc_;
   }

  private:
-  cudnnRNNDescriptor_t rnn_desc_;
+  cudnnRNNDescriptor_t cudnn_rnn_desc_;
+};
+
+struct RNNDescriptors {
+  CudnnFilterDescriptor filter_desc;
+  CudnnRNN rnn_desc;
+
+  OrtMutex mutex;
 };

 template <typename T>
@@ -85,6 +92,8 @@ class CudnnRnnBase : public CudaKernel {

   Status ComputeInternal(OpKernelContext* ctx) const override;

+  void SetRNNMode(cudnnRNNMode_t rnn_mode) { rnn_mode_ = rnn_mode; }
+
  private:
   Status SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle,
                                const cudnnRNNDescriptor_t rnn_desc,
@@ -115,19 +124,22 @@ class CudnnRnnBase : public CudaKernel {
   int64_t num_directions_;
   // required
   int64_t hidden_size_;
-  cudnnRNNMode_t rnn_mode_;
   std::vector<int> W_lin_layer_id_;
   std::vector<int> R_lin_layer_id_;
-  CudnnRNN rnn_desc_;
   bool reverse_;
   int num_layers_;

  private:
+  cudnnRNNMode_t rnn_mode_;
   // optional
   std::string direction_;
+  // w_desc_cache_ & cudnn_dropout_desc_ are modified in the constructor only
   CudnnFilterDescriptor w_desc_cache_;
   CudnnDropout cudnn_dropout_desc_;
-  CudnnFilterDescriptor filter_desc_;
+
+  // filter_desc & rnn_desc may be modified in Compute
+  mutable RNNDescriptors rnn_descriptors_;
+
   IAllocatorUniquePtr<void> w_data_cache_;
   bool weight_cached_;
   IAllocatorUniquePtr<void> state_buffer_;
diff --git a/onnxruntime/core/providers/cuda/rnn/gru.h b/onnxruntime/core/providers/cuda/rnn/gru.h
index 43a0ba4ab5878..186ace6a80292 100644
--- a/onnxruntime/core/providers/cuda/rnn/gru.h
+++ b/onnxruntime/core/providers/cuda/rnn/gru.h
@@ -15,7 +15,7 @@ template <typename T>
 class GRU final : public CudnnRnnBase<T> {
  public:
   GRU(const OpKernelInfo& info) : CudnnRnnBase<T>(info) {
-    CudnnRnnBase<T>::rnn_mode_ = CUDNN_GRU;
+    CudnnRnnBase<T>::SetRNNMode(CUDNN_GRU);
     CudnnRnnBase<T>::SetCudnnRnnDesc();

     // ONNX W layout is Wzrh, WBzrh, mapping to RNNLinLayerMatrixParams the linLayerID is 1, 0, 2
diff --git a/onnxruntime/core/providers/cuda/rnn/lstm.h b/onnxruntime/core/providers/cuda/rnn/lstm.h
index 3ba719d61750d..61a70e8a5730f 100644
--- a/onnxruntime/core/providers/cuda/rnn/lstm.h
+++ b/onnxruntime/core/providers/cuda/rnn/lstm.h
@@ -13,7 +13,7 @@ class LSTM final : public CudnnRnnBase<T> {
  public:
   LSTM(const OpKernelInfo& info) : CudnnRnnBase<T>(info) {
-    CudnnRnnBase<T>::rnn_mode_ = CUDNN_LSTM;
+    CudnnRnnBase<T>::SetRNNMode(CUDNN_LSTM);
     CudnnRnnBase<T>::SetCudnnRnnDesc();

     // ONNX W layout is W[iofc], WB[iofc], mapping to RNNLinLayerMatrixParams the linLayerID is 0, 3, 1, 2
diff --git a/onnxruntime/core/providers/cuda/rnn/rnn.h b/onnxruntime/core/providers/cuda/rnn/rnn.h
index 246e8d1062df0..9113897bcc99e 100644
--- a/onnxruntime/core/providers/cuda/rnn/rnn.h
+++ b/onnxruntime/core/providers/cuda/rnn/rnn.h
@@ -20,9 +20,9 @@ class RNN final : public CudnnRnnBase<T> {
     std::vector<std::string> activations_;
     ORT_ENFORCE(info.GetAttrs("activations", activations_).IsOK());
     if (activations_[0] == "Relu")
-      CudnnRnnBase<T>::rnn_mode_ = CUDNN_RNN_RELU;
+      CudnnRnnBase<T>::SetRNNMode(CUDNN_RNN_RELU);
     else if (activations_[0] == "Tanh")
-      CudnnRnnBase<T>::rnn_mode_ = CUDNN_RNN_TANH;
+      CudnnRnnBase<T>::SetRNNMode(CUDNN_RNN_TANH);

     CudnnRnnBase<T>::SetCudnnRnnDesc();

From d887a5ded3f8f0bb936d83c11d5fb306e5a44cd7 Mon Sep 17 00:00:00 2001
From: Hector Li
Date: Wed, 7 Aug 2019 11:33:50 -0700
Subject: [PATCH 2/3] Create temporary cuDNN descriptors instead of a shared
 descriptor with a lock
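A minimal sketch of the approach this patch switches to (an illustrative
stand-in, not the actual kernel code; the real CudnnRNN RAII wrapper lives in
cudnn_rnn_base.h): each Compute() call builds its own short-lived descriptor,
so concurrent calls never share a mutable cuDNN object and no mutex is needed.

// Hypothetical stand-in for the RAII descriptor wrapper in cudnn_rnn_base.h.
struct CudnnRNN {
  CudnnRNN() { /* cudnnCreateRNNDescriptor(...) in the real wrapper */ }
  ~CudnnRNN() { /* cudnnDestroyRNNDescriptor(...) in the real wrapper */ }
  void Set(/* hidden_size, num_layers, dropout, direction, mode, ... */) {}
};

void ComputeInternal() {
  CudnnRNN rnn_desc;  // local to this call; nothing is shared across threads
  rnn_desc.Set();
  // ... run the cuDNN forward pass with rnn_desc ...
}  // descriptor destroyed when the call returns

The trade-off is a small per-call descriptor create/destroy cost in exchange
for lock-free concurrency, which the patch evidently considers cheap relative
to the RNN forward computation itself.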
---
 .../core/providers/cuda/cudnn_common.h        |   2 +-
 .../core/providers/cuda/rnn/cudnn_rnn_base.cc | 184 +++++++++---------
 .../core/providers/cuda/rnn/cudnn_rnn_base.h  |  56 +++---
 onnxruntime/core/providers/cuda/rnn/gru.h     |   1 -
 onnxruntime/core/providers/cuda/rnn/lstm.h    |   1 -
 onnxruntime/core/providers/cuda/rnn/rnn.h     |   2 -
 6 files changed, 116 insertions(+), 130 deletions(-)

diff --git a/onnxruntime/core/providers/cuda/cudnn_common.h b/onnxruntime/core/providers/cuda/cudnn_common.h
index 02a3ba6b694bb..bfd233b68b65e 100644
--- a/onnxruntime/core/providers/cuda/cudnn_common.h
+++ b/onnxruntime/core/providers/cuda/cudnn_common.h
@@ -97,13 +97,13 @@ class CudnnDropout final {
     return dropout_desc_;
   }

- private:
   Status CreateDescriptorIfNeeded() {
     if (!dropout_desc_)
       CUDNN_RETURN_IF_ERROR(cudnnCreateDropoutDescriptor(&dropout_desc_));
     return Status::OK();
   }

+ private:
   cudnnDropoutDescriptor_t dropout_desc_;
 };

diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc
index 5bf396c5c9272..e61aaee1c74cc 100644
--- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc
+++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc
@@ -15,7 +15,7 @@ void CudnnRnnBase<T>::SetWeightBias(const cudnnHandle_t handle,
                                     const cudnnTensorDescriptor_t x_desc,
                                     const cudnnFilterDescriptor_t w_desc,
                                     const cudnnFilterDescriptor_t filter_desc,
-                                    const void* w_data,
+                                    const void* reorganized_w_data,
                                     const int lin_layer_id,
                                     const T* pos,
                                     int& offset,
@@ -27,9 +27,9 @@ void CudnnRnnBase<T>::SetWeightBias(const cudnnHandle_t handle,
   T* mem_offset;

   if (is_matrix) {
-    cudnnGetRNNLinLayerMatrixParams(handle, rnn_desc, pseudo_layer, x_desc, w_desc, w_data, lin_layer_id, filter_desc, (void**)&mem_offset);
+    cudnnGetRNNLinLayerMatrixParams(handle, rnn_desc, pseudo_layer, x_desc, w_desc, reorganized_w_data, lin_layer_id, filter_desc, (void**)&mem_offset);
   } else {
-    cudnnGetRNNLinLayerBiasParams(handle, rnn_desc, pseudo_layer, x_desc, w_desc, w_data, lin_layer_id, filter_desc, (void**)&mem_offset);
+    cudnnGetRNNLinLayerBiasParams(handle, rnn_desc, pseudo_layer, x_desc, w_desc, reorganized_w_data, lin_layer_id, filter_desc, (void**)&mem_offset);
   }

   cudnnGetFilterNdDescriptor(filter_desc, 3, &dt, &tf, &numDims, matDims.data());
@@ -42,25 +42,25 @@ Status CudnnRnnBase<T>::SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle,
                                               const cudnnRNNDescriptor_t rnn_desc,
                                               const cudnnTensorDescriptor_t x_desc,
                                               const cudnnFilterDescriptor_t w_desc,
-                                              void* w_data,
+                                              void* reorganized_w_data,
                                               const T* W_data,
                                               const T* R_data,
                                               const T* B_data) const {
-  //Onnx only support 1 layer
   int w_offset = 0;
   int r_offset = 0;
   int bias_offset = 0;
-  for (int layer = 0; layer < num_layers_ * num_directions_; ++layer) {
+  CudnnFilterDescriptor filter_desc;
+  for (int layer = 0; layer < RNN_NUM_LAYERS * num_directions_; ++layer) {
     for (size_t idx = 0; idx < W_lin_layer_id_.size(); ++idx) {
-      SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, rnn_descriptors_.filter_desc, w_data, W_lin_layer_id_[idx], W_data, w_offset, true);
+      SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, W_lin_layer_id_[idx], W_data, w_offset, true);
       if (B_data != nullptr) {
-        SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, rnn_descriptors_.filter_desc, w_data, W_lin_layer_id_[idx], B_data, bias_offset, false);
+        SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, W_lin_layer_id_[idx], B_data, bias_offset, false);
       }
     }
     for (size_t idx = 0; idx < R_lin_layer_id_.size(); ++idx) {
-      SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, rnn_descriptors_.filter_desc, w_data, R_lin_layer_id_[idx], R_data, r_offset, true);
+      SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, R_lin_layer_id_[idx], R_data, r_offset, true);
       if (B_data != nullptr) {
-        SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, rnn_descriptors_.filter_desc, w_data, R_lin_layer_id_[idx], B_data, bias_offset, false);
+        SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, R_lin_layer_id_[idx], B_data, bias_offset, false);
       }
     }
   }
@@ -68,34 +68,11 @@ Status CudnnRnnBase<T>::SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle,
   return Status::OK();
 }

-template <typename T>
-Status CudnnRnnBase<T>::SetCudnnRnnDesc() {
-  typedef typename ToCudaType<T>::MappedType CudaT;
-
-  cudnnDirectionMode_t cudnn_direction = CUDNN_UNIDIRECTIONAL;
-  if (direction_ == "bidirectional") {
-    cudnn_direction = CUDNN_BIDIRECTIONAL;
-  } else if (direction_ == "forward") {
-    cudnn_direction = CUDNN_UNIDIRECTIONAL;
-  } else if (direction_ == "reverse") {
-    cudnn_direction = CUDNN_UNIDIRECTIONAL;
-    // need to reverse data
-    reverse_ = true;
-  }
-
-  cudnn_dropout_desc_.GetCudnnDropoutStatesSize(CudnnHandle(), state_size_);
-  state_buffer_ = GetScratchBuffer<void>(state_size_);
-  cudnn_dropout_desc_.Set(CudnnHandle(), state_buffer_.get(), state_size_);
-  ORT_RETURN_IF_ERROR(rnn_descriptors_.rnn_desc.Set(CudnnHandle(), hidden_size_, num_layers_, cudnn_dropout_desc_,
-                                                    cudnn_direction, rnn_mode_, CudnnTensor::GetDataType<CudaT>()));
-
-  return Status::OK();
-}
-
 template <typename T>
 Status CudnnRnnBase<T>::ReorganizeWeights(const Tensor* W, const Tensor* R, const Tensor* B,
-                                          IAllocatorUniquePtr<void>& target_w_data,
-                                          CudnnFilterDescriptor& target_w_desc) const {
+                                          IAllocatorUniquePtr<void>& reorganized_w_data,
+                                          CudnnFilterDescriptor& target_w_desc,
+                                          CudnnRNN& rnn_desc) const {
   typedef typename ToCudaType<T>::MappedType CudaT;

   int64_t input_size = W->Shape()[2];
   // RNN W[num_directions_, hidden_size_, input_size]
@@ -117,20 +94,21 @@ Status CudnnRnnBase<T>::ReorganizeWeights(const Tensor* W, const Tensor* R, const Tensor* B,
   fake_x_desc.Set(fake_dims_x, CudnnTensor::GetDataType<CudaT>());

   // Prepare the weight data
-  target_w_data = GetScratchBuffer<void>(w_size * sizeof(T));
+  reorganized_w_data = GetScratchBuffer<void>(w_size * sizeof(T));
   const T* W_data = W->template Data<T>();
   const T* R_data = R->template Data<T>();
   const T* B_data = B == nullptr ? nullptr : B->template Data<T>();

-  ORT_RETURN_IF_ERROR(SetCudnnRnnWeightBias(CudnnHandle(), rnn_descriptors_.rnn_desc, fake_x_desc, target_w_desc,
-                                            target_w_data.get(), W_data, R_data, B_data));
+  ORT_RETURN_IF_ERROR(SetCudnnRnnWeightBias(CudnnHandle(), rnn_desc, fake_x_desc, target_w_desc,
+                                            reorganized_w_data.get(), W_data, R_data, B_data));

   return Status::OK();
 }

 template <typename T>
 Status CudnnRnnBase<T>::CacheCudnnRnnWeights(const OpKernelInfo& info) {
+  typedef typename ToCudaType<T>::MappedType CudaT;
   // Cache the weight
   const Tensor* W;
   const Tensor* R;
@@ -140,10 +118,15 @@ Status CudnnRnnBase<T>::CacheCudnnRnnWeights(const OpKernelInfo& info) {
   bool get_B = info.TryGetConstantInput(RNN_Input_Index::B, &B);

   if (get_W && get_R) {
+    CudnnRNN tmp_rnn_desc;
+    CudnnDropout cudnn_dropout_desc;
+    cudnn_dropout_desc.CreateDescriptorIfNeeded();
+    ORT_RETURN_IF_ERROR(tmp_rnn_desc.Set(CudnnHandle(), hidden_size_, RNN_NUM_LAYERS, cudnn_dropout_desc,
+                                         cudnn_direction_mode_, rnn_mode_, CudnnTensor::GetDataType<CudaT>()));
     if (get_B) {
-      ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, B, w_data_cache_, w_desc_cache_));
+      ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, B, w_data_cache_, w_desc_cache_, tmp_rnn_desc));
     } else {
-      ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, nullptr, w_data_cache_, w_desc_cache_));
+      ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, nullptr, w_data_cache_, w_desc_cache_, tmp_rnn_desc));
     }
     weight_cached_ = true;
   }
@@ -173,7 +156,7 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {

   // optional outputs
   std::vector<int64_t> dims_Y({seq_length, num_directions_, batch_size, hidden_size_});
-  std::vector<int64_t> dims_hxy({num_layers_ * num_directions_, batch_size, hidden_size_});
+  std::vector<int64_t> dims_hxy({RNN_NUM_LAYERS * num_directions_, batch_size, hidden_size_});
   std::vector<int64_t> dims_yc{num_directions_, batch_size, hidden_size_};
   Tensor* Y = ctx->Output(Output_Index::Y, dims_Y);
   Tensor* Y_h = ctx->Output(Output_Index::Y_h, dims_hxy);
@@ -229,31 +212,66 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
   const int32_t* sequence_lens_data = (sequence_lens == nullptr) ? nullptr : sequence_lens->template Data<int32_t>();
-  {
-    std::lock_guard<OrtMutex> lock(rnn_descriptors_.mutex);
-
-    // Prepare the weight data
-    IAllocatorUniquePtr<void> w_data;
-    CudnnFilterDescriptor w_desc;
-    if (!weight_cached_) {
-      const Tensor& W = *ctx->Input<Tensor>(RNN_Input_Index::W);
-      const Tensor& R = *ctx->Input<Tensor>(RNN_Input_Index::R);
-      const Tensor* B = ctx->Input<Tensor>(RNN_Input_Index::B);
-      ReorganizeWeights(&W, &R, B, w_data, w_desc);
-    }
+  //CudnnFilterDescriptor filter_desc;
+  CudnnRNN rnn_desc;
+  //CudnnDropout cudnn_dropout_desc;
+  IAllocatorUniquePtr<void> state_buffer;
+  size_t state_size;
+  CudnnDropout cudnn_dropout_desc;
+  cudnn_dropout_desc.CreateDescriptorIfNeeded();

-    // CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED works with CUDNN_RNN_PADDED_IO_ENABLED, so that shorter sequences are automatically zero-padded
-    CUDNN_RETURN_IF_ERROR(cudnnSetRNNPaddingMode(rnn_descriptors_.rnn_desc, CUDNN_RNN_PADDED_IO_ENABLED));
+  cudnn_dropout_desc.GetCudnnDropoutStatesSize(CudnnHandle(), state_size);
+  state_buffer = GetScratchBuffer<void>(state_size);
+  cudnn_dropout_desc.Set(CudnnHandle(), state_buffer.get(), state_size);
+  ORT_RETURN_IF_ERROR(rnn_desc.Set(CudnnHandle(), hidden_size_, RNN_NUM_LAYERS, cudnn_dropout_desc,
+                                   cudnn_direction_mode_, rnn_mode_, CudnnTensor::GetDataType<CudaT>()));

-    size_t workspace_bytes;
-    CUDNN_RETURN_IF_ERROR(cudnnGetRNNWorkspaceSize(CudnnHandle(), rnn_descriptors_.rnn_desc, gsl::narrow_cast<int>(seq_length), x_desc.data(), &workspace_bytes));
-    auto workspace_cuda = GetScratchBuffer<void>(workspace_bytes);
+  // Prepare the weight data
+  IAllocatorUniquePtr<void> w_data;
+  CudnnFilterDescriptor w_desc;
+  if (!weight_cached_) {
+    const Tensor& W = *ctx->Input<Tensor>(RNN_Input_Index::W);
+    const Tensor& R = *ctx->Input<Tensor>(RNN_Input_Index::R);
+    const Tensor* B = ctx->Input<Tensor>(RNN_Input_Index::B);
+    ReorganizeWeights(&W, &R, B, w_data, w_desc, rnn_desc);
+  }

-    if (CUDNN_RNN_RELU == rnn_mode_ || CUDNN_RNN_TANH == rnn_mode_ || nullptr == sequence_lens_data) {
-      CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInference(CudnnHandle(),
-                                                     rnn_descriptors_.rnn_desc,
-                                                     gsl::narrow_cast<int>(seq_length),
-                                                     x_desc.data(),
+  // CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED works with CUDNN_RNN_PADDED_IO_ENABLED, so that shorter sequences are automatically zero-padded
+  CUDNN_RETURN_IF_ERROR(cudnnSetRNNPaddingMode(rnn_desc, CUDNN_RNN_PADDED_IO_ENABLED));
+
+  size_t workspace_bytes;
+  CUDNN_RETURN_IF_ERROR(cudnnGetRNNWorkspaceSize(CudnnHandle(), rnn_desc, gsl::narrow_cast<int>(seq_length), x_desc.data(), &workspace_bytes));
+  auto workspace_cuda = GetScratchBuffer<void>(workspace_bytes);
+
+  if (CUDNN_RNN_RELU == rnn_mode_ || CUDNN_RNN_TANH == rnn_mode_ || nullptr == sequence_lens_data) {
+    CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInference(CudnnHandle(),
+                                                   rnn_desc,
+                                                   gsl::narrow_cast<int>(seq_length),
+                                                   x_desc.data(),
+                                                   x_data_input,
+                                                   hx_desc,
+                                                   hx_data,
+                                                   cx_desc,
+                                                   cx_data,
+                                                   weight_cached_ ? w_desc_cache_ : w_desc,
+                                                   weight_cached_ ? w_data_cache_.get() : w_data.get(),
+                                                   y_desc.data(),
+                                                   y_data,
+                                                   y_h_desc,
+                                                   y_h_data,
+                                                   y_c_desc,
+                                                   y_c_data,
+                                                   workspace_cuda.get(),
+                                                   workspace_bytes));
+  } else {
+    CudnnDataTensor x_desc;
+    x_desc.Set(CudnnTensor::GetDataType<CudaT>(), seq_length, batch_size, input_size, sequence_lens_data);
+    CudnnDataTensor y_desc;
+    y_desc.Set(CudnnTensor::GetDataType<CudaT>(), seq_length, batch_size, hidden_size_ * num_directions_, sequence_lens_data);
+
+    CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInferenceEx(CudnnHandle(),
+                                                     rnn_desc,
+                                                     x_desc,
                                                      x_data_input,
                                                      hx_desc,
                                                      hx_data,
@@ -261,48 +279,22 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
                                                      cx_data,
                                                      weight_cached_ ? w_desc_cache_ : w_desc,
                                                      weight_cached_ ? w_data_cache_.get() : w_data.get(),
-                                                     y_desc.data(),
+                                                     y_desc,
                                                      y_data,
                                                      y_h_desc,
                                                      y_h_data,
                                                      y_c_desc,
                                                      y_c_data,
+                                                     nullptr, nullptr, nullptr, nullptr,
+                                                     nullptr, nullptr, nullptr, nullptr,
                                                      workspace_cuda.get(),
                                                      workspace_bytes));
-  } else {
-    CudnnDataTensor x_desc;
-    x_desc.Set(CudnnTensor::GetDataType<CudaT>(), seq_length, batch_size, input_size, sequence_lens_data);
-    CudnnDataTensor y_desc;
-    y_desc.Set(CudnnTensor::GetDataType<CudaT>(), seq_length, batch_size, hidden_size_ * num_directions_, sequence_lens_data);
-
-    CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInferenceEx(CudnnHandle(),
-                                                     rnn_descriptors_.rnn_desc,
-                                                     x_desc,
-                                                     x_data_input,
-                                                     hx_desc,
-                                                     hx_data,
-                                                     cx_desc,
-                                                     cx_data,
-                                                     weight_cached_ ? w_desc_cache_ : w_desc,
-                                                     weight_cached_ ? w_data_cache_.get() : w_data.get(),
-                                                     y_desc,
-                                                     y_data,
-                                                     y_h_desc,
-                                                     y_h_data,
-                                                     y_c_desc,
-                                                     y_c_data,
-                                                     nullptr, nullptr, nullptr, nullptr,
-                                                     nullptr, nullptr, nullptr, nullptr,
-                                                     workspace_cuda.get(),
-                                                     workspace_bytes));
-    // Early terminate for this case since Y data is not required, and Y_h is obtained correctly; the following code to retrieve Y_h from the Y data is not needed.
-    if (nullptr == Y) {
-      return Status::OK();
-    }
+  // Early terminate for this case since Y data is not required, and Y_h is obtained correctly; the following code to retrieve Y_h from the Y data is not needed.
+  if (nullptr == Y) {
+    return Status::OK();
   }
 }
-
 IAllocatorUniquePtr<T> y_reorganized_data;
 if (reverse_ || num_directions_ == 2) {
   //reverse output
diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h
index 2c41aee3cb02a..f7d83b4d7c05a 100644
--- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h
+++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h
@@ -21,6 +21,9 @@ enum RNN_Input_Index {
   initial_c = 6
 };

+// ONNX RNN/GRU/LSTM only support 1 layer
+const int RNN_NUM_LAYERS = 1;
+
 class CudnnRNN {
  public:
   CudnnRNN() : cudnn_rnn_desc_(nullptr) {
@@ -61,13 +64,6 @@ class CudnnRNN {
   cudnnRNNDescriptor_t cudnn_rnn_desc_;
 };

-struct RNNDescriptors {
-  CudnnFilterDescriptor filter_desc;
-  CudnnRNN rnn_desc;
-
-  OrtMutex mutex;
-};
-
 template <typename T>
@@ -75,19 +71,28 @@ class CudnnRnnBase : public CudaKernel {
   const std::set<std::string> allowed_directions{"forward", "reverse", "bidirectional"};

  public:
   CudnnRnnBase(const OpKernelInfo& info) : CudaKernel{info} {
     reverse_ = false;
-    ORT_ENFORCE(info.GetAttr("direction", &direction_).IsOK());
-    num_directions_ = direction_ == "bidirectional" ? 2 : 1;
-    ORT_ENFORCE(allowed_directions.find(direction_) != allowed_directions.end());
+    std::string direction = "forward";
+    direction = info.GetAttrOrDefault<std::string>("direction", "forward");
+    cudnn_direction_mode_ = CUDNN_UNIDIRECTIONAL;
+    if (direction == "bidirectional") {
+      cudnn_direction_mode_ = CUDNN_BIDIRECTIONAL;
+    } else if (direction == "forward") {
+      cudnn_direction_mode_ = CUDNN_UNIDIRECTIONAL;
+    } else if (direction == "reverse") {
+      cudnn_direction_mode_ = CUDNN_UNIDIRECTIONAL;
+      // need to reverse data
+      reverse_ = true;
+    }
+
+    num_directions_ = cudnn_direction_mode_ == CUDNN_BIDIRECTIONAL ? 2 : 1;
+    ORT_ENFORCE(allowed_directions.find(direction) != allowed_directions.end());

     ORT_ENFORCE(info.GetAttr("hidden_size", &hidden_size_).IsOK() && hidden_size_ > 0);
     rnn_mode_ = CUDNN_LSTM;
-    num_layers_ = 1;
     weight_cached_ = false;
     w_data_cache_ = nullptr;
   }

-  Status SetCudnnRnnDesc();
-
   Status CacheCudnnRnnWeights(const OpKernelInfo& info);

   Status ComputeInternal(OpKernelContext* ctx) const override;
@@ -106,7 +111,8 @@ class CudnnRnnBase : public CudaKernel {
   Status ReorganizeWeights(const Tensor* W, const Tensor* R, const Tensor* B,
                            IAllocatorUniquePtr<void>& target_w_data,
-                           CudnnFilterDescriptor& target_w_desc) const;
+                           CudnnFilterDescriptor& target_w_desc,
+                           CudnnRNN& rnn_desc) const;

   void SetWeightBias(const cudnnHandle_t handle,
                      const cudnnRNNDescriptor_t rnn_desc,
@@ -121,29 +127,21 @@ class CudnnRnnBase : public CudaKernel {
                      bool is_matrix) const;

  protected:
-  int64_t num_directions_;
-  // required
-  int64_t hidden_size_;
+  // W_lin_layer_id_ & R_lin_layer_id_ are set in the constructor
   std::vector<int> W_lin_layer_id_;
   std::vector<int> R_lin_layer_id_;
-  bool reverse_;
-  int num_layers_;

  private:
+  cudnnDirectionMode_t cudnn_direction_mode_;
+  bool reverse_;
+  int64_t num_directions_;
+  // hidden_size_ from attribute
+  int64_t hidden_size_;
   cudnnRNNMode_t rnn_mode_;
-  // optional
-  std::string direction_;
-  // w_desc_cache_ & cudnn_dropout_desc_ are modified in the constructor only
+  // w_desc_cache_ & w_data_cache_ are set in the constructor if the weights are available as constant inputs
   CudnnFilterDescriptor w_desc_cache_;
-  CudnnDropout cudnn_dropout_desc_;
-
-  // filter_desc & rnn_desc may be modified in Compute
-  mutable RNNDescriptors rnn_descriptors_;
-
   IAllocatorUniquePtr<void> w_data_cache_;
   bool weight_cached_;
-  IAllocatorUniquePtr<void> state_buffer_;
-  size_t state_size_;

   enum Output_Index {
     Y = 0,
diff --git a/onnxruntime/core/providers/cuda/rnn/gru.h b/onnxruntime/core/providers/cuda/rnn/gru.h
index 186ace6a80292..ab9dabff5db36 100644
--- a/onnxruntime/core/providers/cuda/rnn/gru.h
+++ b/onnxruntime/core/providers/cuda/rnn/gru.h
@@ -16,7 +16,6 @@ class GRU final : public CudnnRnnBase<T> {
  public:
   GRU(const OpKernelInfo& info) : CudnnRnnBase<T>(info) {
     CudnnRnnBase<T>::SetRNNMode(CUDNN_GRU);
-    CudnnRnnBase<T>::SetCudnnRnnDesc();

     // ONNX W layout is Wzrh, WBzrh, mapping to RNNLinLayerMatrixParams the linLayerID is 1, 0, 2
     CudnnRnnBase<T>::W_lin_layer_id_.assign({1, 0, 2});
diff --git a/onnxruntime/core/providers/cuda/rnn/lstm.h b/onnxruntime/core/providers/cuda/rnn/lstm.h
index 61a70e8a5730f..3ed12cfa7fff9 100644
--- a/onnxruntime/core/providers/cuda/rnn/lstm.h
+++ b/onnxruntime/core/providers/cuda/rnn/lstm.h
@@ -14,7 +14,6 @@ class LSTM final : public CudnnRnnBase<T> {
  public:
   LSTM(const OpKernelInfo& info) : CudnnRnnBase<T>(info) {
     CudnnRnnBase<T>::SetRNNMode(CUDNN_LSTM);
-    CudnnRnnBase<T>::SetCudnnRnnDesc();

     // ONNX W layout is W[iofc], WB[iofc], mapping to RNNLinLayerMatrixParams the linLayerID is 0, 3, 1, 2
     CudnnRnnBase<T>::W_lin_layer_id_.assign({0, 3, 1, 2});
diff --git a/onnxruntime/core/providers/cuda/rnn/rnn.h b/onnxruntime/core/providers/cuda/rnn/rnn.h
index 9113897bcc99e..dbb0d2843fe11 100644
--- a/onnxruntime/core/providers/cuda/rnn/rnn.h
+++ b/onnxruntime/core/providers/cuda/rnn/rnn.h
@@ -24,8 +24,6 @@ class RNN final : public CudnnRnnBase<T> {
     else if (activations_[0] == "Tanh")
       CudnnRnnBase<T>::SetRNNMode(CUDNN_RNN_TANH);

-    CudnnRnnBase<T>::SetCudnnRnnDesc();
-
     // ONNX W mapping to RNNLinLayerMatrixParams the linLayerID is 0
     CudnnRnnBase<T>::W_lin_layer_id_.assign({0});
     // ONNX R mapping to RNNLinLayerMatrixParams the linLayerID is 1

From 7e8d75c5b6061b7bfb41bf828b2ba1e47d98ab1d Mon Sep 17 00:00:00 2001
From: Hector Li
Date: Thu, 8 Aug 2019 12:02:34 -0700
Subject: [PATCH 3/3] Cache cudnn_dropout_desc_, which never changes
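The pattern, in a minimal self-contained sketch (placeholder types, not the
actual kernel code): state that is written exactly once in the constructor can
be treated as read-only from Compute(), so concurrent calls can share it
without any synchronization.

// Hypothetical stand-in for the CudnnDropout wrapper.
struct DropoutDescriptor {
  void CreateDescriptorIfNeeded() {}
  void Set(void* states, size_t state_size) {}
};

class RnnKernel {
 public:
  RnnKernel() {
    // One-time setup: after the constructor returns, dropout_desc_ is never
    // mutated again, so Compute() may read it from any thread without a lock.
    dropout_desc_.CreateDescriptorIfNeeded();
    dropout_desc_.Set(/*states=*/nullptr, /*state_size=*/0);
  }

  void Compute() const {
    // read-only use of dropout_desc_ here
  }

 private:
  DropoutDescriptor dropout_desc_;  // initialized once, then immutable
};

This also removes the per-call dropout-state allocation that patch 2
introduced, leaving only the short-lived rnn_desc to be created per call.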
---
 .../core/providers/cuda/rnn/cudnn_rnn_base.cc | 16 ++--------------
 .../core/providers/cuda/rnn/cudnn_rnn_base.h  | 10 ++++++++++
 2 files changed, 12 insertions(+), 14 deletions(-)

diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc
index e61aaee1c74cc..26d87bcdeef00 100644
--- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc
+++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc
@@ -119,9 +119,7 @@ Status CudnnRnnBase<T>::CacheCudnnRnnWeights(const OpKernelInfo& info) {

   if (get_W && get_R) {
     CudnnRNN tmp_rnn_desc;
-    CudnnDropout cudnn_dropout_desc;
-    cudnn_dropout_desc.CreateDescriptorIfNeeded();
-    ORT_RETURN_IF_ERROR(tmp_rnn_desc.Set(CudnnHandle(), hidden_size_, RNN_NUM_LAYERS, cudnn_dropout_desc,
+    ORT_RETURN_IF_ERROR(tmp_rnn_desc.Set(CudnnHandle(), hidden_size_, RNN_NUM_LAYERS, cudnn_dropout_desc_,
                                          cudnn_direction_mode_, rnn_mode_, CudnnTensor::GetDataType<CudaT>()));
     if (get_B) {
       ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, B, w_data_cache_, w_desc_cache_, tmp_rnn_desc));
@@ -212,18 +210,8 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
   const int32_t* sequence_lens_data = (sequence_lens == nullptr) ? nullptr : sequence_lens->template Data<int32_t>();

-  //CudnnFilterDescriptor filter_desc;
   CudnnRNN rnn_desc;
-  //CudnnDropout cudnn_dropout_desc;
-  IAllocatorUniquePtr<void> state_buffer;
-  size_t state_size;
-  CudnnDropout cudnn_dropout_desc;
-  cudnn_dropout_desc.CreateDescriptorIfNeeded();
-
-  cudnn_dropout_desc.GetCudnnDropoutStatesSize(CudnnHandle(), state_size);
-  state_buffer = GetScratchBuffer<void>(state_size);
-  cudnn_dropout_desc.Set(CudnnHandle(), state_buffer.get(), state_size);
-  ORT_RETURN_IF_ERROR(rnn_desc.Set(CudnnHandle(), hidden_size_, RNN_NUM_LAYERS, cudnn_dropout_desc,
+  ORT_RETURN_IF_ERROR(rnn_desc.Set(CudnnHandle(), hidden_size_, RNN_NUM_LAYERS, cudnn_dropout_desc_,
                                    cudnn_direction_mode_, rnn_mode_, CudnnTensor::GetDataType<CudaT>()));

   // Prepare the weight data
diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h
index f7d83b4d7c05a..3a4e234c19aad 100644
--- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h
+++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h
@@ -91,6 +91,12 @@ class CudnnRnnBase : public CudaKernel {
     rnn_mode_ = CUDNN_LSTM;
     weight_cached_ = false;
     w_data_cache_ = nullptr;
+
+    size_t state_size;
+    cudnn_dropout_desc_.CreateDescriptorIfNeeded();
+    cudnn_dropout_desc_.GetCudnnDropoutStatesSize(CudnnHandle(), state_size);
+    state_buffer_ = GetScratchBuffer<void>(state_size);
+    cudnn_dropout_desc_.Set(CudnnHandle(), state_buffer_.get(), state_size);
   }

   Status CacheCudnnRnnWeights(const OpKernelInfo& info);
@@ -143,6 +149,10 @@ class CudnnRnnBase : public CudaKernel {
   IAllocatorUniquePtr<void> w_data_cache_;
   bool weight_cached_;

+  // cudnn_dropout_desc_ is set up once in the constructor and never changed afterwards
+  IAllocatorUniquePtr<void> state_buffer_;
+  CudnnDropout cudnn_dropout_desc_;
+
   enum Output_Index {
     Y = 0,
     Y_h = 1,