Skip to content

Commit

Permalink
Bug fix for shape of optional output in Dropout op (#1507)
Browse files Browse the repository at this point in the history
* Bug fix for shape of optional output in Dropout op

* Exclude new test from NGraph EP

* Account for the fact that mask could be of different type in different opset variants of the op

* Make accompanying Cuda changes

* Fix build break

* Exclude Opset 7 test for TensorRT EP

* PR comments
  • Loading branch information
hariharans29 authored Aug 1, 2019
1 parent 57e2482 commit 465b30e
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 6 deletions.
13 changes: 12 additions & 1 deletion onnxruntime/core/providers/cpu/tensor/identity_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,18 @@ class IdentityOp final : public OpKernel {
}

if (is_dropout) {
context->Output(1, std::vector<int64_t>());
Tensor* mask = context->Output(1, shape);
// a 'nullptr' returned would make it an unused optional output
if (mask != nullptr) {
// Opset 7 differs with Opset 10 in that the type of the 'mask'
// output is tied with the type of the input in Opset 7 whereas
// the type of 'mask' in Opset 10 is 'bool' always
// so we have a common solution
void* mask_data = mask->MutableDataRaw();
// In 'test'/'inference' mode, there are no input values dropped out
// so fill the buffer with 0/false
memset(mask_data, 0, mask->SizeInBytes());
}
}

return Status::OK();
Expand Down
6 changes: 4 additions & 2 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Un
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 8, Flatten);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Squeeze);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Identity);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, Dropout);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 9, Dropout);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Gather);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, float, Gemm);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, double, Gemm);
Expand Down Expand Up @@ -515,6 +515,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, Shrink);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, double, Shrink);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, MLFloat16, Shrink);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, Dropout);

static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
Expand All @@ -525,7 +526,7 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 8, Flatten)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Squeeze)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Identity)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, Dropout)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 9, Dropout)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Gather)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, float, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, double, Gemm)>,
Expand Down Expand Up @@ -840,6 +841,7 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, Shrink)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, double, Shrink)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, MLFloat16, Shrink)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, Dropout)>,
};

for (auto& function_table_entry : function_table) {
Expand Down
19 changes: 17 additions & 2 deletions onnxruntime/core/providers/cuda/tensor/identity_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,28 @@

namespace onnxruntime {
namespace cuda {
// Registers the CUDA Dropout kernel for opsets 7 through 9.
// In these opsets the optional 'mask' output shares type constraint "T"
// with the input, so no separate "T1" constraint is declared here.
// Alias(0, 0) lets the output 'Y' reuse the input buffer (Dropout is an
// identity op in inference mode, hence IdentityOp<true>).
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Dropout,
kOnnxDomain,
7, 9,
kCudaExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", {DataTypeImpl::GetTensorType<MLFloat16>(),
DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<double>()})
.Alias(0, 0),
IdentityOp<true>);

// Registers the CUDA Dropout kernel for opset 10 onward.
// Unlike opsets 7-9, opset 10 fixes the optional 'mask' output to type
// bool, expressed via the separate "T1" type constraint below.
// NOTE: the scraped diff left both the old opset argument ('7,') and the
// old single-line "T" TypeConstraint interleaved with the new lines; the
// duplicate arguments are removed here so the macro invocation is valid.
// Alias(0, 0) lets output 'Y' reuse the input buffer (IdentityOp<true>
// treats Dropout as identity in inference mode).
ONNX_OPERATOR_KERNEL_EX(
    Dropout,
    kOnnxDomain,
    10,
    kCudaExecutionProvider,
    KernelDefBuilder()
        .TypeConstraint("T", {DataTypeImpl::GetTensorType<MLFloat16>(),
                              DataTypeImpl::GetTensorType<float>(),
                              DataTypeImpl::GetTensorType<double>()})
        .TypeConstraint("T1", DataTypeImpl::GetTensorType<bool>())
        .Alias(0, 0),
    IdentityOp<true>);

Expand Down
13 changes: 12 additions & 1 deletion onnxruntime/core/providers/cuda/tensor/identity_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,18 @@ class IdentityOp final : public CudaKernel {
}

if (is_dropout) {
context->Output(1, std::vector<int64_t>());
Tensor* mask = context->Output(1, shape);
// a 'nullptr' returned would make it an unused optional output
if (mask != nullptr) {
// Opset 7 differs with Opset 10 in that the type of the 'mask'
// output is tied with the type of the input in Opset 7 whereas
// the type of 'mask' in Opset 10 is 'bool' always
// so we have a common solution
void* mask_data = mask->MutableDataRaw();
// In 'test'/'inference' mode, there are no input values dropped out
// so fill the buffer with 0/false
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(mask_data, 0, mask->SizeInBytes()));
}
}

return Status::OK();
Expand Down
24 changes: 24 additions & 0 deletions onnxruntime/test/providers/cpu/nn/dropout_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,29 @@ TEST(Dropout, Opset10) {
test.Run();
}

// Verifies inference-mode Dropout (opset 10) with the optional 'mask'
// output requested: Y must equal X and the mask must be all-false
// (nothing is dropped in inference mode), with mask typed as bool.
TEST(Dropout, WithOptionalOutputOpset10) {
  OpTester test("Dropout", 10, kOnnxDomain);
  const std::vector<int64_t> input_dims{2, 2};
  const std::vector<float> input_values{1.0f, 2.0f, 3.0f, 5.0f};
  test.AddInput<float>("X", input_dims, input_values);
  test.AddOutput<float>("Y", input_dims, input_values);
  test.AddOutput<bool>("mask", input_dims, {false, false, false, false});
  // The NGraph execution provider doesn't seem to support 'Dropout' with optional mask output
  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kNGraphExecutionProvider});
}

// Verifies inference-mode Dropout (opset 7) with the optional 'mask'
// output requested. In opset 7 the 'mask' output shares the input's
// element type (float here), whereas opset 10 pins 'mask' to bool —
// so the expected mask is a float tensor of zeros.
TEST(Dropout, WithOptionalOutputOpset7) {
  OpTester test("Dropout", 7, kOnnxDomain);
  const std::vector<int64_t> input_dims{2, 2};
  const std::vector<float> input_values{1.0f, 2.0f, 3.0f, 5.0f};
  test.AddInput<float>("X", input_dims, input_values);
  test.AddOutput<float>("Y", input_dims, input_values);
  test.AddOutput<float>("mask", input_dims, {0.0f, 0.0f, 0.0f, 0.0f});
  // The NGraph execution provider doesn't seem to support 'Dropout' with optional mask output
  // The TensorRT execution provider doesn't seem to support 'Dropout' with non-boolean mask output
  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kNGraphExecutionProvider, kTensorrtExecutionProvider});
}

} // namespace test
} // namespace onnxruntime

0 comments on commit 465b30e

Please sign in to comment.