Fix race condition issue in RNN/LSTM/GRU #1544

Merged · 5 commits · Aug 8, 2019
125 changes: 65 additions & 60 deletions onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc
@@ -52,15 +52,15 @@ Status CudnnRnnBase<T>::SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle,
int bias_offset = 0;
for (int layer = 0; layer < num_layers_ * num_directions_; ++layer) {
for (size_t idx = 0; idx < W_lin_layer_id_.size(); ++idx) {
SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc_, w_data, W_lin_layer_id_[idx], W_data, w_offset, true);
SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, rnn_descriptors_.filter_desc, w_data, W_lin_layer_id_[idx], W_data, w_offset, true);
if (B_data != nullptr) {
SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc_, w_data, W_lin_layer_id_[idx], B_data, bias_offset, false);
SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, rnn_descriptors_.filter_desc, w_data, W_lin_layer_id_[idx], B_data, bias_offset, false);
}
}
for (size_t idx = 0; idx < R_lin_layer_id_.size(); ++idx) {
SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc_, w_data, R_lin_layer_id_[idx], R_data, r_offset, true);
SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, rnn_descriptors_.filter_desc, w_data, R_lin_layer_id_[idx], R_data, r_offset, true);
if (B_data != nullptr) {
SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc_, w_data, R_lin_layer_id_[idx], B_data, bias_offset, false);
SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, rnn_descriptors_.filter_desc, w_data, R_lin_layer_id_[idx], B_data, bias_offset, false);
}
}
}
@@ -86,8 +86,8 @@ Status CudnnRnnBase<T>::SetCudnnRnnDesc() {
cudnn_dropout_desc_.GetCudnnDropoutStatesSize(CudnnHandle(), state_size_);
state_buffer_ = GetScratchBuffer<void>(state_size_);
cudnn_dropout_desc_.Set(CudnnHandle(), state_buffer_.get(), state_size_);
ORT_RETURN_IF_ERROR(rnn_desc_.Set(CudnnHandle(), hidden_size_, num_layers_, cudnn_dropout_desc_,
cudnn_direction, rnn_mode_, CudnnTensor::GetDataType<CudaT>()));
ORT_RETURN_IF_ERROR(rnn_descriptors_.rnn_desc.Set(CudnnHandle(), hidden_size_, num_layers_, cudnn_dropout_desc_,
cudnn_direction, rnn_mode_, CudnnTensor::GetDataType<CudaT>()));

return Status::OK();
}
@@ -123,8 +123,8 @@ Status CudnnRnnBase<T>::ReorganizeWeights(const Tensor* W, const Tensor* R, cons
const T* R_data = R->template Data<T>();
const T* B_data = B == nullptr ? nullptr : B->template Data<T>();

ORT_RETURN_IF_ERROR(SetCudnnRnnWeightBias(CudnnHandle(), rnn_desc_, fake_x_desc, target_w_desc,
target_w_data.get(), W_data, R_data, B_data));
ORT_RETURN_IF_ERROR(SetCudnnRnnWeightBias(CudnnHandle(), rnn_descriptors_.rnn_desc, fake_x_desc, target_w_desc,
target_w_data.get(), W_data, R_data, B_data));

return Status::OK();
}
@@ -198,16 +198,6 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
ORT_RETURN_IF_ERROR(y_h_desc.Set(dims_hxy, CudnnTensor::GetDataType<CudaT>()));
ORT_RETURN_IF_ERROR(y_c_desc.Set(dims_hxy, CudnnTensor::GetDataType<CudaT>()));

// Prepare the weight data
IAllocatorUniquePtr<void> w_data;
CudnnFilterDescriptor w_desc;
if (!weight_cached_) {
const Tensor& W = *ctx->Input<Tensor>(RNN_Input_Index::W);
const Tensor& R = *ctx->Input<Tensor>(RNN_Input_Index::R);
const Tensor* B = ctx->Input<Tensor>(RNN_Input_Index::B);
ReorganizeWeights(&W, &R, B, w_data, w_desc);
}

IAllocatorUniquePtr<T> x_reversed_data;
const T* x_data = X->template Data<T>();
if (reverse_) {
@@ -239,65 +229,80 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {

const int32_t* sequence_lens_data = (sequence_lens == nullptr) ? nullptr : sequence_lens->template Data<int32_t>();

// CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED works with CUDNN_RNN_PADDED_IO_ENABLED, so that it will auto fill 0 for the shorter sequences
CUDNN_RETURN_IF_ERROR(cudnnSetRNNPaddingMode(rnn_desc_, CUDNN_RNN_PADDED_IO_ENABLED));

size_t workspace_bytes;
CUDNN_RETURN_IF_ERROR(cudnnGetRNNWorkspaceSize(CudnnHandle(), rnn_desc_, gsl::narrow_cast<int>(seq_length), x_desc.data(), &workspace_bytes));
auto workspace_cuda = GetScratchBuffer<void>(workspace_bytes);

if (CUDNN_RNN_RELU == rnn_mode_ || CUDNN_RNN_TANH == rnn_mode_ || nullptr == sequence_lens_data) {
CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInference(CudnnHandle(),
rnn_desc_,
gsl::narrow_cast<int>(seq_length),
x_desc.data(),
x_data_input,
hx_desc,
hx_data,
cx_desc,
cx_data,
weight_cached_ ? w_desc_cache_ : w_desc,
weight_cached_ ? w_data_cache_.get() : w_data.get(),
y_desc.data(),
y_data,
y_h_desc,
y_h_data,
y_c_desc,
y_c_data,
workspace_cuda.get(),
workspace_bytes));
} else {
CudnnDataTensor x_desc;
x_desc.Set(CudnnTensor::GetDataType<CudaT>(), seq_length, batch_size, input_size, sequence_lens_data);
CudnnDataTensor y_desc;
y_desc.Set(CudnnTensor::GetDataType<CudaT>(), seq_length, batch_size, hidden_size_ * num_directions_, sequence_lens_data);

CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInferenceEx(CudnnHandle(),
rnn_desc_,
x_desc,
{
std::lock_guard<OrtMutex> lock(rnn_descriptors_.mutex);

// Prepare the weight data
IAllocatorUniquePtr<void> w_data;
CudnnFilterDescriptor w_desc;
if (!weight_cached_) {
const Tensor& W = *ctx->Input<Tensor>(RNN_Input_Index::W);
const Tensor& R = *ctx->Input<Tensor>(RNN_Input_Index::R);
const Tensor* B = ctx->Input<Tensor>(RNN_Input_Index::B);
ReorganizeWeights(&W, &R, B, w_data, w_desc);
}

// CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED works with CUDNN_RNN_PADDED_IO_ENABLED, so that it will auto fill 0 for the shorter sequences
CUDNN_RETURN_IF_ERROR(cudnnSetRNNPaddingMode(rnn_descriptors_.rnn_desc, CUDNN_RNN_PADDED_IO_ENABLED));

size_t workspace_bytes;
CUDNN_RETURN_IF_ERROR(cudnnGetRNNWorkspaceSize(CudnnHandle(), rnn_descriptors_.rnn_desc, gsl::narrow_cast<int>(seq_length), x_desc.data(), &workspace_bytes));
auto workspace_cuda = GetScratchBuffer<void>(workspace_bytes);

if (CUDNN_RNN_RELU == rnn_mode_ || CUDNN_RNN_TANH == rnn_mode_ || nullptr == sequence_lens_data) {
CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInference(CudnnHandle(),
rnn_descriptors_.rnn_desc,
gsl::narrow_cast<int>(seq_length),
x_desc.data(),
x_data_input,
hx_desc,
hx_data,
cx_desc,
cx_data,
weight_cached_ ? w_desc_cache_ : w_desc,
weight_cached_ ? w_data_cache_.get() : w_data.get(),
y_desc,
y_desc.data(),
y_data,
y_h_desc,
y_h_data,
y_c_desc,
y_c_data,
nullptr, nullptr, nullptr, nullptr,
nullptr, nullptr, nullptr, nullptr,
workspace_cuda.get(),
workspace_bytes));
// Early terminate for this case since Y data is not required, and Y_h is obtained correctly, no need the following code to retrive Y_h from Y data.
if (nullptr == Y) {
return Status::OK();
} else {
CudnnDataTensor x_desc;
x_desc.Set(CudnnTensor::GetDataType<CudaT>(), seq_length, batch_size, input_size, sequence_lens_data);
CudnnDataTensor y_desc;
y_desc.Set(CudnnTensor::GetDataType<CudaT>(), seq_length, batch_size, hidden_size_ * num_directions_, sequence_lens_data);

CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInferenceEx(CudnnHandle(),
Review comments on this line — cudnnRNNForwardInferenceEx

ke1337 (Contributor), Aug 2, 2019:
Seems inference is now protected by the mutex. An alternative might be to always create the descriptors on the stack. We may need some perf tests to find out which is better. #Resolved

Member, in reply to 310228482:
Yes, we may need to try creating the descriptors locally. If the perf is good, it would be better to implement a stateless Compute.

HectorSVC (Contributor, Author), Aug 7, 2019:
I implemented it another way, creating the descriptors on the stack, and compared the performance; it is much slower than the current implementation.

Contributor, in reply to 311792641:
Can you confirm the slowdown comes from creating the tensor descriptors on the stack? Is ReorganizeWeights protected by the mutex and shared between threads?
rnn_descriptors_.rnn_desc,
x_desc,
x_data_input,
hx_desc,
hx_data,
cx_desc,
cx_data,
weight_cached_ ? w_desc_cache_ : w_desc,
weight_cached_ ? w_data_cache_.get() : w_data.get(),
y_desc,
y_data,
y_h_desc,
y_h_data,
y_c_desc,
y_c_data,
nullptr, nullptr, nullptr, nullptr,
nullptr, nullptr, nullptr, nullptr,
workspace_cuda.get(),
workspace_bytes));
// Early terminate for this case since Y data is not required, and Y_h is obtained correctly, no need the following code to retrive Y_h from Y data.
if (nullptr == Y) {
return Status::OK();
}
}
}


IAllocatorUniquePtr<T> y_reorganized_data;
if (reverse_ || num_directions_ == 2) {
//reverse output
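To make the review discussion above concrete, here is a minimal sketch of the pattern this change adopts: the descriptors that ComputeInternal may modify are grouped with a mutex, and the (const) compute path takes the lock before touching them. This is an illustration only, using std::mutex and hypothetical stub types rather than the repository's OrtMutex, CudnnRNN, and CudnnFilterDescriptor wrappers.

```cpp
#include <mutex>

// Hypothetical stand-ins for the real descriptor wrappers (CudnnRNN and
// CudnnFilterDescriptor in cudnn_rnn_base.h), which manage the underlying
// cudnnRNNDescriptor_t and cudnnFilterDescriptor_t handles.
struct RnnDescriptorStub {};
struct FilterDescriptorStub {};

// Mirrors the RNNDescriptors struct added by this PR: the state that
// ComputeInternal may touch is grouped with the mutex that guards it.
struct GuardedRnnDescriptors {
  FilterDescriptorStub filter_desc;
  RnnDescriptorStub rnn_desc;
  std::mutex mutex;  // the real code uses OrtMutex
};

class RnnKernelSketch {
 public:
  // Const from the framework's point of view, yet it still needs to update
  // the cached descriptors, hence the mutable member and the lock below.
  void ComputeInternal(/* OpKernelContext* ctx */) const {
    std::lock_guard<std::mutex> lock(descriptors_.mutex);
    // ... set the padding mode, query the workspace size, and run
    // cudnnRNNForwardInference / cudnnRNNForwardInferenceEx while no other
    // thread can modify rnn_desc or filter_desc ...
  }

 private:
  mutable GuardedRnnDescriptors descriptors_;  // shared by concurrent calls
};
```

Serializing the body of ComputeInternal this way trades some parallelism for the ability to keep reusing the cached descriptors, which is exactly the trade-off debated in the thread above.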
36 changes: 24 additions & 12 deletions onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h
@@ -23,24 +23,24 @@ enum RNN_Input_Index {

class CudnnRNN {
public:
CudnnRNN() : rnn_desc_(nullptr) {
CudnnRNN() : cudnn_rnn_desc_(nullptr) {
}

~CudnnRNN() {
if (rnn_desc_ != nullptr) {
cudnnDestroyRNNDescriptor(rnn_desc_);
rnn_desc_ = nullptr;
if (cudnn_rnn_desc_ != nullptr) {
cudnnDestroyRNNDescriptor(cudnn_rnn_desc_);
cudnn_rnn_desc_ = nullptr;
}
}

Status Set(const cudnnHandle_t& cudnnHandle, int64_t hidden_size, int num_layers,
cudnnDropoutDescriptor_t cudnn_dropout_desc, cudnnDirectionMode_t cudnn_direction_model,
cudnnRNNMode_t rnn_mode, cudnnDataType_t dataType) {
if (!rnn_desc_)
CUDNN_RETURN_IF_ERROR(cudnnCreateRNNDescriptor(&rnn_desc_));
if (!cudnn_rnn_desc_)
CUDNN_RETURN_IF_ERROR(cudnnCreateRNNDescriptor(&cudnn_rnn_desc_));

CUDNN_RETURN_IF_ERROR(cudnnSetRNNDescriptor(cudnnHandle,
rnn_desc_,
cudnn_rnn_desc_,
gsl::narrow_cast<int>(hidden_size),
num_layers,
cudnn_dropout_desc,
@@ -54,11 +54,18 @@ class CudnnRNN {
}

operator cudnnRNNDescriptor_t() const {
return rnn_desc_;
return cudnn_rnn_desc_;
}

private:
cudnnRNNDescriptor_t rnn_desc_;
cudnnRNNDescriptor_t cudnn_rnn_desc_;
};

struct RNNDescriptors {
CudnnFilterDescriptor filter_desc;
CudnnRNN rnn_desc;

OrtMutex mutex;
};

template <typename T>
@@ -85,6 +92,8 @@ class CudnnRnnBase : public CudaKernel {

Status ComputeInternal(OpKernelContext* ctx) const override;

void SetRNNMode(cudnnRNNMode_t rnn_mode) { rnn_mode_ = rnn_mode; }

private:
Status SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle,
const cudnnRNNDescriptor_t rnn_desc,
@@ -115,19 +124,22 @@ class CudnnRnnBase : public CudaKernel {
int64_t num_directions_;
// required
int64_t hidden_size_;
cudnnRNNMode_t rnn_mode_;
std::vector<int> W_lin_layer_id_;
std::vector<int> R_lin_layer_id_;
CudnnRNN rnn_desc_;
bool reverse_;
int num_layers_;

private:
cudnnRNNMode_t rnn_mode_;
// optional
std::string direction_;
// w_desc_cache_ & cudnn_dropout_desc_ are changed in Constructor only
CudnnFilterDescriptor w_desc_cache_;
CudnnDropout cudnn_dropout_desc_;
CudnnFilterDescriptor filter_desc_;

// filter_desc & rnn_desc could be changed in Compute
mutable RNNDescriptors rnn_descriptors_;

IAllocatorUniquePtr<void> w_data_cache_;
bool weight_cached_;
IAllocatorUniquePtr<void> state_buffer_;
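For contrast, the alternative raised in the review thread (and reported by the author to be measurably slower) would build the descriptors on the stack inside each call, so nothing mutable is shared and no lock is needed. A rough sketch, again with hypothetical stub types rather than the real wrappers:

```cpp
// Hypothetical stand-ins for the real descriptor wrappers; the genuine classes
// create and destroy the underlying cuDNN handles in their ctor/dtor.
struct LocalRnnDescriptorStub {};
struct LocalFilterDescriptorStub {};

// "Stateless" variant discussed in the review: the descriptors live on the
// stack, so nothing mutable is shared between threads and no mutex is needed,
// but every call repeats the create/configure/destroy work.
void ComputeInternalStateless(/* OpKernelContext* ctx */) {
  LocalRnnDescriptorStub rnn_desc;        // per-call descriptor, never shared
  LocalFilterDescriptorStub filter_desc;  // likewise

  // ... configure rnn_desc / filter_desc, reorganize the weights, and run the
  // cuDNN forward inference using only these local objects ...
}  // descriptors are destroyed here; no state outlives the call
```

Repeating the descriptor creation and configuration on every invocation is presumably where the slowdown the author measured comes from, which is why the PR keeps the cached descriptors and guards them with the mutex instead.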
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cuda/rnn/gru.h
@@ -15,7 +15,7 @@ template <typename T>
class GRU final : public CudnnRnnBase<T> {
public:
GRU(const OpKernelInfo& info) : CudnnRnnBase<T>(info) {
CudnnRnnBase<T>::rnn_mode_ = CUDNN_GRU;
CudnnRnnBase<T>::SetRNNMode(CUDNN_GRU);
CudnnRnnBase<T>::SetCudnnRnnDesc();

// ONNX W layout is Wzrh, WBzrh, mapping to RNNLinLayerMatrixParams the linLayerID is 1, 0, 2
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cuda/rnn/lstm.h
@@ -13,7 +13,7 @@ class LSTM final : public CudnnRnnBase<T> {

public:
LSTM(const OpKernelInfo& info) : CudnnRnnBase<T>(info) {
CudnnRnnBase<T>::rnn_mode_ = CUDNN_LSTM;
CudnnRnnBase<T>::SetRNNMode(CUDNN_LSTM);
CudnnRnnBase<T>::SetCudnnRnnDesc();

// ONNX W layout is W[iofc], WB[iofc], mapping to RNNLinLayerMatrixParams the linLayerID is 0, 3, 1, 2
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/cuda/rnn/rnn.h
@@ -20,9 +20,9 @@ class RNN final : public CudnnRnnBase<T> {
std::vector<std::string> activations_;
ORT_ENFORCE(info.GetAttrs("activations", activations_).IsOK());
if (activations_[0] == "Relu")
CudnnRnnBase<T>::rnn_mode_ = CUDNN_RNN_RELU;
CudnnRnnBase<T>::SetRNNMode(CUDNN_RNN_RELU);
else if (activations_[0] == "Tanh")
CudnnRnnBase<T>::rnn_mode_ = CUDNN_RNN_TANH;
CudnnRnnBase<T>::SetRNNMode(CUDNN_RNN_TANH);

CudnnRnnBase<T>::SetCudnnRnnDesc();
