diff --git a/onnxruntime/core/providers/cpu/tensor/squeeze.h b/onnxruntime/core/providers/cpu/tensor/squeeze.h index 2387832018ae4..f6489e5cf2e03 100644 --- a/onnxruntime/core/providers/cpu/tensor/squeeze.h +++ b/onnxruntime/core/providers/cpu/tensor/squeeze.h @@ -13,13 +13,15 @@ class SqueezeBase { protected: explicit SqueezeBase(const OpKernelInfo& info) { std::vector axes; - Status status = info.GetAttrs("axes", axes); - ORT_ENFORCE(status.IsOK(), "Attribute axes is not set."); + // Parse attribute 'axes' + Status status = info.GetAttrs("axes", axes); - // Handle out of order and repeating dims. - std::sort(axes.begin(), axes.end()); - axes.erase(std::unique(axes.begin(), axes.end()), axes.end()); - axes_ = axes; + // Handle out of order and repeating dims when 'axes' exists. + if (status.IsOK()) { + std::sort(axes.begin(), axes.end()); + axes.erase(std::unique(axes.begin(), axes.end()), axes.end()); + axes_ = axes; + } } static std::vector ComputeOutputShape( @@ -28,7 +30,8 @@ class SqueezeBase { size_t j = 0; std::vector output_shape; for (size_t i = 0; i < input_shape.NumDimensions(); ++i) { - if (j < axes.NumDimensions() && axes[j] == static_cast(i)) { + if ((j < axes.NumDimensions() && axes[j] == static_cast(i)) || + (axes.NumDimensions() == 0 && input_shape[i] == 1)) { ORT_ENFORCE(input_shape[i] == 1, "Dimension of input ", i, " must be 1 instead of ", input_shape[i], ". shape=", input_shape); ++j; @@ -59,4 +62,4 @@ class Squeeze final : public OpKernel, public SqueezeBase { } }; -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/test/providers/cpu/tensor/squeeze_op_test.cc b/onnxruntime/test/providers/cpu/tensor/squeeze_op_test.cc index 979512e8bf739..4287d1369bd65 100644 --- a/onnxruntime/test/providers/cpu/tensor/squeeze_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/squeeze_op_test.cc @@ -18,6 +18,23 @@ TEST(SqueezeOpTest, Squeeze_1) { test.Run(); } +TEST(SqueezeOpTest, Squeeze_Empty_Axes_1) { + OpTester test("Squeeze"); + test.AddInput("data", {1, 1, 4, 1}, std::vector(4, 1.0f)); + test.AddOutput("squeezed", {4}, std::vector(4, 1.0f)); + // TensorRT doesn't seem to support missing 'axes' + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +TEST(SqueezeOpTest, Squeeze_Empty_Axes_2) { + OpTester test("Squeeze"); + // nothing to "squeeze" out in the input shape + test.AddInput("data", {2, 4}, std::vector(8, 1.0f)); + test.AddOutput("squeezed", {2, 4}, std::vector(8, 1.0f)); + // TensorRT doesn't seem to support missing 'axes' + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + TEST(SqueezeOpTest, Squeeze_1_int32) { OpTester test("Squeeze"); test.AddAttribute("axes", std::vector{0});