Skip to content

Commit

Permalink
Support non-default negative axis value and intuitive data type combi…
Browse files Browse the repository at this point in the history
…nation for OneHot op (#1317) (#1732)

* Handle nondefault negative axis value

* Support more intuitive data types for this op
  • Loading branch information
hariharans29 authored Sep 4, 2019
1 parent 71e0c44 commit 9c1ce29
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 7 deletions.
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/cpu/cpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float_string_int64_t, OneHot);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float_float_float, OneHot);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t_int32_t_float, OneHot);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t_float_int64_t, OneHot);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MaxUnpool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Sinh);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Cosh);
Expand Down Expand Up @@ -480,6 +481,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float_string_int64_t, OneHot)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float_float_float, OneHot)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t_int32_t_float, OneHot)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t_float_int64_t, OneHot)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MaxUnpool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Sinh)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Cosh)>,
Expand Down
21 changes: 17 additions & 4 deletions onnxruntime/core/providers/cpu/tensor/onehot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ REG_ONE_HOT_OP(int64_t, int64_t, int64_t);
REG_ONE_HOT_OP(float, int64_t, int64_t);
REG_ONE_HOT_OP(int64_t, string, int64_t);
REG_ONE_HOT_OP(float, string, int64_t);
REG_ONE_HOT_OP(int64_t, float, int64_t);
REG_ONE_HOT_OP(float, float, float); // added this to satisfy onnx model tests
REG_ONE_HOT_OP(int64_t, int32_t, float); // added this to satisfy onnx model tests

Expand Down Expand Up @@ -120,16 +121,28 @@ Status OneHotOp<in_type, out_type, depth_type>::Compute(OpKernelContext* p_op_ke
const auto& indices_dims = indices_shape.GetDims();
const auto indices_num_dims = indices_shape.NumDimensions();
std::vector<int64_t> output_shape(indices_shape.GetDims());
output_shape.insert(axis_ == -1 ? output_shape.end() : output_shape.begin() + axis_,
depth_val);

// output rank is always 1 more than the input rank as a new dimension is added to the input shape
const auto output_rank = static_cast<int64_t>(indices_num_dims + 1);
if (axis_ >= output_rank || axis_ < -output_rank) {
std::ostringstream oss;
oss << "'axis' attribute must have a value in the range [" << -output_rank
<< "," << indices_num_dims << "]";
return Status(ONNXRUNTIME, INVALID_ARGUMENT, oss.str());
}

auto true_axis = axis_;
if (true_axis < 0)
true_axis += output_rank;

output_shape.insert(output_shape.begin() + true_axis, depth_val);

// allocate output
const auto* values_data = values->Data<out_type>();
Tensor* output = p_op_kernel_context->Output(0, TensorShape(output_shape));

const int64_t axis = (axis_ == -1) ? indices_num_dims : axis_;
int64_t prefix_dim_size = 1;
for (int64_t i = 0; i < axis; ++i) {
for (int64_t i = 0; i < true_axis; ++i) {
prefix_dim_size *= indices_dims[i];
}
const int64_t suffix_dim_size = indices_shape.Size() / prefix_dim_size;
Expand Down
3 changes: 0 additions & 3 deletions onnxruntime/core/providers/cpu/tensor/onehot.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@ class OneHotOp final : public OpKernel {
explicit OneHotOp(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info) {
int64_t tmp_axis;
if (op_kernel_info.GetAttr<int64_t>("axis", &tmp_axis).IsOK()) {
if (tmp_axis < -1) { // as per spec it can be -1 or more
ORT_THROW("Value of axis is < -1");
}
axis_ = tmp_axis;
}
}
Expand Down
34 changes: 34 additions & 0 deletions onnxruntime/test/providers/cpu/tensor/onehot_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,20 @@ TEST(OneHotOpTest, DefaultAxis_int64_int32_float /*indices, output, depth*/) {
test.Run();
}

TEST(OneHotOpTest, DefaultAxis_int64_float_int64 /*indices, output, depth*/) {
OpTester test("OneHot", 9);
test.AddInput<int64_t>("indices", {2, 3}, {1, 9, 8, 2, 4, 6});
test.AddInput<int64_t>("depth", {1}, {10});
test.AddInput<float>("values", {2}, {0, 1});
test.AddOutput<float>("output", {2, 3, 10}, {0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 0, 0, 0,});
test.Run();
}

TEST(OneHotOpTest, Axis_0) {
OpTester test("OneHot", 9);
int64_t axis = 0;
Expand Down Expand Up @@ -117,6 +131,26 @@ TEST(OneHotOpTest, Axis_2) {
test.Run();
}

TEST(OneHotOpTest, Axis_Negative_NonDefault) {
OpTester test("OneHot", 9);
int64_t axis = -3;
test.AddAttribute("axis", axis);
test.AddInput<int64_t>("indices", {2, 3}, {1, 9, 8, 2, 4, 6});
test.AddInput<int64_t>("depth", {1}, {10});
test.AddInput<int64_t>("values", {2}, {0, 1});
test.AddOutput<int64_t>("output", {10, 2, 3}, { 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 0, 0,
0, 0, 0, 1, 0, 0,
0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1, 0,
0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1,
0, 0, 0, 0, 0, 0,
0, 0, 1, 0, 0, 0,
0, 1, 0, 0, 0, 0,});
test.Run();
}

TEST(OneHotOpTest, FloatInt64) {
OpTester test("OneHot", 9);
test.AddInput<float>("indices", {2, 3}, {1.f, 9.f, 8.f, 2.f, 4.f, 6.f});
Expand Down

0 comments on commit 9c1ce29

Please sign in to comment.