diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 06f9c395dd627..4088f137b5676 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -218,6 +218,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float_string_int64_t, OneHot); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float_float_float, OneHot); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t_int32_t_float, OneHot); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t_float_int64_t, OneHot); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MaxUnpool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Sinh); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Cosh); @@ -480,6 +481,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/tensor/onehot.cc b/onnxruntime/core/providers/cpu/tensor/onehot.cc index e35e1b6b72729..1dfbaaf37640f 100644 --- a/onnxruntime/core/providers/cpu/tensor/onehot.cc +++ b/onnxruntime/core/providers/cpu/tensor/onehot.cc @@ -45,6 +45,7 @@ REG_ONE_HOT_OP(int64_t, int64_t, int64_t); REG_ONE_HOT_OP(float, int64_t, int64_t); REG_ONE_HOT_OP(int64_t, string, int64_t); REG_ONE_HOT_OP(float, string, int64_t); +REG_ONE_HOT_OP(int64_t, float, int64_t); REG_ONE_HOT_OP(float, float, float); // added this to satisfy onnx model tests REG_ONE_HOT_OP(int64_t, int32_t, float); // added this to satisfy onnx model tests @@ -120,16 +121,28 @@ Status OneHotOp::Compute(OpKernelContext* p_op_ke const auto& indices_dims = indices_shape.GetDims(); const auto indices_num_dims = indices_shape.NumDimensions(); std::vector output_shape(indices_shape.GetDims()); - output_shape.insert(axis_ == -1 ? output_shape.end() : output_shape.begin() + axis_, - depth_val); + + // output rank is always 1 more than the input rank as a new dimension is added to the input shape + const auto output_rank = static_cast(indices_num_dims + 1); + if (axis_ >= output_rank || axis_ < -output_rank) { + std::ostringstream oss; + oss << "'axis' attribute must have a value in the range [" << -output_rank + << "," << indices_num_dims << "]"; + return Status(ONNXRUNTIME, INVALID_ARGUMENT, oss.str()); + } + + auto true_axis = axis_; + if (true_axis < 0) + true_axis += output_rank; + + output_shape.insert(output_shape.begin() + true_axis, depth_val); // allocate output const auto* values_data = values->Data(); Tensor* output = p_op_kernel_context->Output(0, TensorShape(output_shape)); - const int64_t axis = (axis_ == -1) ? indices_num_dims : axis_; int64_t prefix_dim_size = 1; - for (int64_t i = 0; i < axis; ++i) { + for (int64_t i = 0; i < true_axis; ++i) { prefix_dim_size *= indices_dims[i]; } const int64_t suffix_dim_size = indices_shape.Size() / prefix_dim_size; diff --git a/onnxruntime/core/providers/cpu/tensor/onehot.h b/onnxruntime/core/providers/cpu/tensor/onehot.h index a05731b777929..495342d99b3e7 100644 --- a/onnxruntime/core/providers/cpu/tensor/onehot.h +++ b/onnxruntime/core/providers/cpu/tensor/onehot.h @@ -14,9 +14,6 @@ class OneHotOp final : public OpKernel { explicit OneHotOp(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info) { int64_t tmp_axis; if (op_kernel_info.GetAttr("axis", &tmp_axis).IsOK()) { - if (tmp_axis < -1) { // as per spec it can be -1 or more - ORT_THROW("Value of axis is < -1"); - } axis_ = tmp_axis; } } diff --git a/onnxruntime/test/providers/cpu/tensor/onehot_op_test.cc b/onnxruntime/test/providers/cpu/tensor/onehot_op_test.cc index 152cb3bcb9345..c816c7f7b6661 100644 --- a/onnxruntime/test/providers/cpu/tensor/onehot_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/onehot_op_test.cc @@ -51,6 +51,20 @@ TEST(OneHotOpTest, DefaultAxis_int64_int32_float /*indices, output, depth*/) { test.Run(); } +TEST(OneHotOpTest, DefaultAxis_int64_float_int64 /*indices, output, depth*/) { + OpTester test("OneHot", 9); + test.AddInput("indices", {2, 3}, {1, 9, 8, 2, 4, 6}); + test.AddInput("depth", {1}, {10}); + test.AddInput("values", {2}, {0, 1}); + test.AddOutput("output", {2, 3, 10}, {0, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, + 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,}); + test.Run(); +} + TEST(OneHotOpTest, Axis_0) { OpTester test("OneHot", 9); int64_t axis = 0; @@ -117,6 +131,26 @@ TEST(OneHotOpTest, Axis_2) { test.Run(); } +TEST(OneHotOpTest, Axis_Negative_NonDefault) { + OpTester test("OneHot", 9); + int64_t axis = -3; + test.AddAttribute("axis", axis); + test.AddInput("indices", {2, 3}, {1, 9, 8, 2, 4, 6}); + test.AddInput("depth", {1}, {10}); + test.AddInput("values", {2}, {0, 1}); + test.AddOutput("output", {10, 2, 3}, { 0, 0, 0, 0, 0, 0, + 1, 0, 0, 0, 0, 0, + 0, 0, 0, 1, 0, 0, + 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 0, + 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 1, + 0, 0, 0, 0, 0, 0, + 0, 0, 1, 0, 0, 0, + 0, 1, 0, 0, 0, 0,}); + test.Run(); +} + TEST(OneHotOpTest, FloatInt64) { OpTester test("OneHot", 9); test.AddInput("indices", {2, 3}, {1.f, 9.f, 8.f, 2.f, 4.f, 6.f});