diff --git a/onnxruntime/core/providers/cuda/cudnn_common.cc b/onnxruntime/core/providers/cuda/cudnn_common.cc
index 910f828c01776..1322b016befd7 100644
--- a/onnxruntime/core/providers/cuda/cudnn_common.cc
+++ b/onnxruntime/core/providers/cuda/cudnn_common.cc
@@ -83,7 +83,8 @@ Status CudnnDataTensor::Set(cudnnDataType_t dataType,
                             const int32_t* seq_lengths) {
   ORT_RETURN_IF_ERROR(CreateTensorIfNeeded());
 
-  cudnnRNNDataLayout_t layout = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED;
+  // CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED works with CUDNN_RNN_PADDED_IO_ENABLED, so cuDNN automatically fills 0 for the shorter sequences.
+  cudnnRNNDataLayout_t layout = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED;
   float padding_fill = 0.0f;
   CUDNN_RETURN_IF_ERROR(cudnnSetRNNDataDescriptor(tensor_, dataType, layout,
                                                   static_cast<int>(max_seq_length),
diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc
index 5a7b0ffbf2c37..4db30f934e6d2 100644
--- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc
+++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc
@@ -220,7 +220,6 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
     x_data = x_reversed_data.get();
   }
 
-  auto byte_size = X->DataType()->Size();
   const T* hx_data = (initial_h == nullptr) ? nullptr : initial_h->template Data<T>();
   const T* cx_data = (initial_c == nullptr) ? nullptr : initial_c->template Data<T>();
   T* y_h_data = (Y_h == nullptr) ? nullptr : Y_h->template MutableData<T>();
@@ -234,10 +233,12 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
     y_alloc_data = GetScratchBuffer<T>(output_size);
     y_data = y_alloc_data.get();
   }
-  // Cudnn library doesn't guarantee the data beyond the shorter sequence will be initialized to 0, so we need to do it manually.
-  cudaMemset(y_data, 0, output_size * byte_size);
+  const int32_t* sequence_lens_data = (sequence_lens == nullptr) ? nullptr : sequence_lens->template Data<int32_t>();
+  // CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED works with CUDNN_RNN_PADDED_IO_ENABLED, so cuDNN automatically fills 0 for the shorter sequences.
+  CUDNN_RETURN_IF_ERROR(cudnnSetRNNPaddingMode(rnn_desc_, CUDNN_RNN_PADDED_IO_ENABLED));
+
   size_t workspace_bytes;
   CUDNN_RETURN_IF_ERROR(cudnnGetRNNWorkspaceSize(CudnnHandle(), rnn_desc_, gsl::narrow_cast<int>(seq_length), x_desc.data(), &workspace_bytes));
   auto workspace_cuda = GetScratchBuffer<void>(workspace_bytes);
@@ -288,6 +289,10 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
                                                      nullptr, nullptr,
                                                      nullptr, nullptr,
                                                      workspace_cuda.get(), workspace_bytes));
+    // Return early for this case: Y data is not required and Y_h has already been obtained correctly,
+    // so the code below that retrieves Y_h from the Y data is not needed.
+    if (nullptr == Y) {
+      return Status::OK();
+    }
   }
 
   IAllocatorUniquePtr<T> y_reorganized_data;
diff --git a/onnxruntime/test/providers/cpu/rnn/deep_cpu_lstm_op_test.cc b/onnxruntime/test/providers/cpu/rnn/deep_cpu_lstm_op_test.cc
index eaa0cc649f91a..7c2d23de5ca4c 100644
--- a/onnxruntime/test/providers/cpu/rnn/deep_cpu_lstm_op_test.cc
+++ b/onnxruntime/test/providers/cpu/rnn/deep_cpu_lstm_op_test.cc
@@ -1090,6 +1090,37 @@ TEST(LSTMTest, ONNXRuntime_TestLSTMSequenceLengthShorterThanInputSequenceLength)
   LstmOpContext2x1x2x2 context(direction);
   context.RunTest(X_data, batch_size, seq_len, &initial_h, &initial_c, Y_data, Y_h_data, {}, &sequence_length);
 }
+
+TEST(LSTMTest, ONNXRuntime_TestLSTMSequenceLengthShorterThanInputSequenceLengthNoP) {
+  const int seq_len = 2;
+  const int batch_size = 1;
+
+  std::vector<float> X_data = {-0.455351f, -0.276391f,
+                               -0.185934f, -0.269585f};
+
+  std::vector<int> sequence_length = {1};
+
+  std::vector<float> initial_h = {0.0f, 0.0f,
+                                  -0.0306872f, 0.028035f};
+
+  std::vector<float> initial_c = {0.0f, 0.0f,
+                                  -0.07243599f, 0.0467052f};
+
+  std::vector<float> Y_data = {0.0415416f, 0.0196912f,
+                               0.0295027f, 0.0334400f,
+
+                               0.0f, 0.0f,
+                               0.0f, 0.0f};
+
+  std::vector<float> Y_h_data = {0.0415416f, 0.0196912f,
+                                 0.0295027f, 0.0334400f};
+
+  std::string direction = "bidirectional";
+
+  LstmOpContext2x1x2x2 context(direction);
+  // The CUDA implementation doesn't support peepholes, so run without them.
+  context.RunTest(X_data, batch_size, seq_len, &initial_h, &initial_c, Y_data, Y_h_data, {}, &sequence_length, false);
+}
 #endif  // USE_NGRAPH
 
 }  // namespace test
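
Note (editor): below is a minimal standalone sketch of the cuDNN padded-I/O configuration this change relies on, assuming the cuDNN 7.x RNN API. The handle, descriptor names, and sizes are illustrative only and are not taken from the ONNX Runtime sources; it just isolates the two calls that matter here: describing the data with CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED plus a zero padding fill, and enabling CUDNN_RNN_PADDED_IO_ENABLED on the RNN descriptor, which is what allows the explicit cudaMemset of the output buffer to be removed.

// padded_rnn_io_sketch.cc -- illustrative only; names are not from ONNX Runtime.
#include <cudnn.h>
#include <cstdio>
#include <vector>

#define CHECK_CUDNN(expr)                                              \
  do {                                                                 \
    cudnnStatus_t status = (expr);                                     \
    if (status != CUDNN_STATUS_SUCCESS) {                              \
      std::printf("cuDNN error: %s\n", cudnnGetErrorString(status));   \
      return 1;                                                        \
    }                                                                  \
  } while (0)

int main() {
  cudnnHandle_t handle;
  CHECK_CUDNN(cudnnCreate(&handle));

  // One batch entry (length 2) is shorter than max_seq_length (4); with the
  // unpacked layout cuDNN pads its tail with padding_fill instead of the
  // caller having to memset the output buffer.
  const int max_seq_length = 4;
  const int batch_size = 2;
  const int input_size = 8;
  std::vector<int> seq_lengths = {4, 2};

  cudnnRNNDataDescriptor_t x_data_desc;
  CHECK_CUDNN(cudnnCreateRNNDataDescriptor(&x_data_desc));
  float padding_fill = 0.0f;
  CHECK_CUDNN(cudnnSetRNNDataDescriptor(x_data_desc, CUDNN_DATA_FLOAT,
                                        CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED,
                                        max_seq_length, batch_size, input_size,
                                        seq_lengths.data(), &padding_fill));

  // The RNN descriptor must opt in to padded I/O for the unpacked layout to be
  // usable with the *Ex forward entry points. The rest of the RNN setup
  // (cudnnSetRNNDescriptor, dropout descriptor, weights) is omitted here.
  cudnnRNNDescriptor_t rnn_desc;
  CHECK_CUDNN(cudnnCreateRNNDescriptor(&rnn_desc));
  CHECK_CUDNN(cudnnSetRNNPaddingMode(rnn_desc, CUDNN_RNN_PADDED_IO_ENABLED));

  cudnnDestroyRNNDescriptor(rnn_desc);
  cudnnDestroyRNNDataDescriptor(x_data_desc);
  cudnnDestroy(handle);
  return 0;
}

In the PR itself the equivalent calls are cudnnSetRNNDataDescriptor in CudnnDataTensor::Set (first hunk) and cudnnSetRNNPaddingMode on rnn_desc_ right before querying the workspace size (third hunk).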