From 94bda2432b9d5d228a7a6ff6ef92c77f7798dc89 Mon Sep 17 00:00:00 2001 From: Attila Dusnoki <126579622+attila-dusnoki-htec@users.noreply.github.com> Date: Tue, 17 Oct 2023 22:07:49 +0200 Subject: [PATCH] Add axes (optional) input to Pad (#2178) --- src/onnx/parse_pad.cpp | 147 ++++++++++---- test/onnx/gen_onnx.py | 182 ++++++++++++++++++ test/onnx/onnx_test.cpp | 82 ++++++++ test/onnx/pad_4arg_axes_test.onnx | Bin 0 -> 305 bytes .../pad_4arg_invalid_axes_error_test.onnx | Bin 0 -> 334 bytes test/onnx/pad_4arg_neg_axes_test.onnx | Bin 0 -> 331 bytes .../pad_asym_invalid_pads_error_test.onnx | Bin 0 -> 171 bytes test/onnx/pad_asym_test.onnx | Bin 0 -> 136 bytes test/onnx/pad_reflect_with_axes_test.onnx | 18 ++ 9 files changed, 396 insertions(+), 33 deletions(-) create mode 100644 test/onnx/pad_4arg_axes_test.onnx create mode 100644 test/onnx/pad_4arg_invalid_axes_error_test.onnx create mode 100644 test/onnx/pad_4arg_neg_axes_test.onnx create mode 100644 test/onnx/pad_asym_invalid_pads_error_test.onnx create mode 100644 test/onnx/pad_asym_test.onnx create mode 100644 test/onnx/pad_reflect_with_axes_test.onnx diff --git a/src/onnx/parse_pad.cpp b/src/onnx/parse_pad.cpp index 5f425211c66..a654ca06b59 100644 --- a/src/onnx/parse_pad.cpp +++ b/src/onnx/parse_pad.cpp @@ -115,34 +115,9 @@ struct parse_pad : op_parser { std::vector operators() const { return {{"Pad"}}; } - instruction_ref parse(const op_desc& /*opd*/, - const onnx_parser& parser, - onnx_parser::node_info info, - std::vector args) const + std::string parse_mode(const onnx_parser::node_info& info, + const std::vector& args) const { - std::vector pads{}; - if(args.size() >= 2) - { - auto pad_arg = args.at(1)->eval(); - check_arg_empty(pad_arg, "PARSE_PAD: pad input must be constant"); - pad_arg.visit([&](auto v) { pads.assign(v.begin(), v.end()); }); - } - else if(contains(info.attributes, "pads")) - { - auto&& pad_vals = info.attributes["pads"].ints(); - pads = std::vector(pad_vals.begin(), pad_vals.end()); - } - else - { - MIGRAPHX_THROW("PARSE_PAD: pad must be available"); - } - - // check if padding is actually being done (at least one value is nonzero) - if(std::all_of(pads.begin(), pads.end(), [](const int& i) { return i == 0; })) - { - return info.add_instruction(make_op("identity"), args.front()); - } - if(contains(info.attributes, "mode")) { auto mode = info.attributes.at("mode").s(); @@ -152,28 +127,59 @@ struct parse_pad : op_parser { MIGRAPHX_THROW("PARSE_PAD: reflect padding with dynamic shape not supported"); } - return reflect_pad(info, pads, args.front()); } - if(mode != "constant") + else if(mode != "constant") { MIGRAPHX_THROW( "PARSE_PAD: migraphx currently only supports constant and reflect padding"); } + return mode; + } + else + { + // default mode + return "constant"; } + } + std::vector parse_pads(const onnx_parser::node_info& info, + const std::vector& args) const + { + std::vector pads{}; + if(args.size() >= 2) + { + auto pad_arg = args.at(1)->eval(); + check_arg_empty(pad_arg, "PARSE_PAD: `pads` input must be constant"); + pad_arg.visit([&](auto v) { pads.assign(v.begin(), v.end()); }); + } + else if(contains(info.attributes, "pads")) + { + auto&& pad_vals = info.attributes.at("pads").ints(); + pads = std::vector(pad_vals.begin(), pad_vals.end()); + } + else + { + MIGRAPHX_THROW("PARSE_PAD: `pads` must be available"); + } + return pads; + } + + float parse_constant_value(const onnx_parser& parser, + const onnx_parser::node_info& info, + const std::vector& args) const + { float value = 0.0f; - // third input is the value - if(args.size() == 3) + if(args.size() >= 3 and args.at(2)->get_shape().scalar()) { auto val_ins = args.at(2); if(not val_ins->can_eval()) { - MIGRAPHX_THROW("PARSE_PAD: input value must be constant"); + MIGRAPHX_THROW("PARSE_PAD: input `value` must be constant"); } auto val_arg = val_ins->eval(); if(val_arg.get_shape().elements() != 1) { - MIGRAPHX_THROW("PARSE_PAD: value should contain only one element"); + MIGRAPHX_THROW("PARSE_PAD: `value` should contain only one element"); } value = val_arg.at(); } @@ -181,6 +187,81 @@ struct parse_pad : op_parser { value = parser.parse_value(info.attributes.at("value")).at(); } + return value; + } + + std::vector parse_axes(const std::vector& args, + bool is_constant_mode) const + { + std::vector axes{}; + // axes is 3rd or 4th, depending on constant mode + auto pos = is_constant_mode ? 4 : 3; + if(args.size() >= pos) + { + auto axes_arg = args.at(pos - 1)->eval(); + check_arg_empty(axes_arg, "PARSE_PAD: variable `axes` input not supported"); + axes_arg.visit([&](auto v) { axes.assign(v.begin(), v.end()); }); + } + return axes; + } + + std::vector calculate_pads_with_axes(const std::vector& pads, + const std::vector& axes, + size_t input_rank) const + { + size_t num_axes = axes.size(); + if(num_axes * 2 != pads.size()) + { + MIGRAPHX_THROW("PARSE_PAD: number of elements of pads should be equal to 2 * " + "number of elements of axes"); + } + + std::vector new_pads(input_rank * 2); + for(size_t idx{0}; idx < num_axes; ++idx) + { + // axis can be negative + int64_t axis = axes[idx] < 0 ? input_rank + axes[idx] : axes[idx]; + // pad format is x1_begin, x2_begin, ... , x3_end, x4_end + new_pads[axis] = pads[idx]; + new_pads[axis + input_rank] = pads[idx + num_axes]; + } + return new_pads; + } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& parser, + const onnx_parser::node_info& info, + const std::vector& args) const + { + std::vector pads = parse_pads(info, args); + + // check if padding is actually being done (at least one value is nonzero) + if(std::all_of(pads.begin(), pads.end(), [](const int& i) { return i == 0; })) + { + return info.add_instruction(make_op("identity"), args.front()); + } + + std::string mode = parse_mode(info, args); + bool is_constant_mode = mode == "constant"; + float value = is_constant_mode ? parse_constant_value(parser, info, args) : 0.0f; + std::vector axes = parse_axes(args, is_constant_mode); + size_t input_rank = args.front()->get_shape().ndim(); + + if(not axes.empty()) + { + pads = calculate_pads_with_axes(pads, axes, input_rank); + } + + if(pads.size() != input_rank * 2) + { + MIGRAPHX_THROW("PARSE_PAD: number of elements of pads should be equal to 2 * " + "input rank"); + } + + if(mode == "reflect") + { + return reflect_pad(info, pads, args.front()); + } return info.add_instruction(migraphx::make_op("pad", {{"pads", pads}, {"value", value}}), args.front()); diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index 577eb99f409..de4493cc6fe 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -5107,6 +5107,32 @@ def pad_test(): return ([node], [x], [y]) +@onnx_test() +def pad_asym_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 4, 5]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 6, 4, 12]) + + node = onnx.helper.make_node('Pad', + inputs=['0'], + pads=[0, 1, 0, 3, 0, 2, 0, 4], + outputs=['1']) + + return ([node], [x], [y]) + + +@onnx_test() +def pad_asym_invalid_pads_error_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 4, 5]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 6, 4, 12]) + + node = onnx.helper.make_node('Pad', + inputs=['0'], + pads=[0, 1, 0, 3, 0, 2], + outputs=['1']) + + return ([node], [x], [y]) + + @onnx_test() def pad_3arg_test(): values = np.array([1]) @@ -5139,6 +5165,129 @@ def pad_3arg_test(): return ([arg_val, arg_pad, node], [x], [y]) +@onnx_test() +def pad_4arg_axes_test(): + values = np.array([1]) + val_tensor = helper.make_tensor(name='val', + data_type=TensorProto.FLOAT, + dims=values.reshape(()).shape, + vals=values.astype(float)) + arg_val = onnx.helper.make_node('Constant', + inputs=[], + outputs=['arg_val'], + value=val_tensor) + + sizes = np.array([1, 3, 2, 4]) + pad_tensor = helper.make_tensor(name='pad_size', + data_type=TensorProto.INT32, + dims=sizes.shape, + vals=sizes.astype(int)) + arg_pad = onnx.helper.make_node('Constant', + inputs=[], + outputs=['arg_pad'], + value=pad_tensor) + + axes = np.array([1, 3]) + axes_tensor = helper.make_tensor(name='pad_axes', + data_type=TensorProto.INT32, + dims=axes.shape, + vals=axes.astype(int)) + arg_axes = onnx.helper.make_node('Constant', + inputs=[], + outputs=['arg_axes'], + value=axes_tensor) + + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 4, 5]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 6, 4, 12]) + + node = onnx.helper.make_node( + 'Pad', inputs=['0', 'arg_pad', 'arg_val', 'arg_axes'], outputs=['1']) + + return ([arg_axes, arg_val, arg_pad, node], [x], [y]) + + +@onnx_test() +def pad_4arg_invalid_axes_error_test(): + values = np.array([1]) + val_tensor = helper.make_tensor(name='val', + data_type=TensorProto.FLOAT, + dims=values.reshape(()).shape, + vals=values.astype(float)) + arg_val = onnx.helper.make_node('Constant', + inputs=[], + outputs=['arg_val'], + value=val_tensor) + + sizes = np.array([1, 3, 2, 4]) + pad_tensor = helper.make_tensor(name='pad_size', + data_type=TensorProto.INT32, + dims=sizes.shape, + vals=sizes.astype(int)) + arg_pad = onnx.helper.make_node('Constant', + inputs=[], + outputs=['arg_pad'], + value=pad_tensor) + + axes = np.array([1, 2, 3]) + axes_tensor = helper.make_tensor(name='pad_axes', + data_type=TensorProto.INT32, + dims=axes.shape, + vals=axes.astype(int)) + arg_axes = onnx.helper.make_node('Constant', + inputs=[], + outputs=['arg_axes'], + value=axes_tensor) + + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 4, 5]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 6, 4, 12]) + + node = onnx.helper.make_node( + 'Pad', inputs=['0', 'arg_pad', 'arg_val', 'arg_axes'], outputs=['1']) + + return ([arg_axes, arg_val, arg_pad, node], [x], [y]) + + +@onnx_test() +def pad_4arg_neg_axes_test(): + values = np.array([1]) + val_tensor = helper.make_tensor(name='val', + data_type=TensorProto.FLOAT, + dims=values.reshape(()).shape, + vals=values.astype(float)) + arg_val = onnx.helper.make_node('Constant', + inputs=[], + outputs=['arg_val'], + value=val_tensor) + + sizes = np.array([1, 3, 2, 4]) + pad_tensor = helper.make_tensor(name='pad_size', + data_type=TensorProto.INT32, + dims=sizes.shape, + vals=sizes.astype(int)) + arg_pad = onnx.helper.make_node('Constant', + inputs=[], + outputs=['arg_pad'], + value=pad_tensor) + + axes = np.array([-3, -1]) + axes_tensor = helper.make_tensor(name='pad_axes', + data_type=TensorProto.INT32, + dims=axes.shape, + vals=axes.astype(int)) + arg_axes = onnx.helper.make_node('Constant', + inputs=[], + outputs=['arg_axes'], + value=axes_tensor) + + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 4, 5]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 6, 4, 12]) + + node = onnx.helper.make_node( + 'Pad', inputs=['0', 'arg_pad', 'arg_val', 'arg_axes'], outputs=['1']) + + return ([arg_axes, arg_val, arg_pad, node], [x], [y]) + + @onnx_test() def pad_reflect_test(): x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 2]) @@ -5162,6 +5311,39 @@ def pad_reflect_test(): return ([arg_pad, node], [x], [y]) +@onnx_test() +def pad_reflect_with_axes_test(): + x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 2]) + y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [2, 5]) + + sizes = np.array([2, 1]) + pad_tensor = helper.make_tensor(name='pad_size', + data_type=TensorProto.INT32, + dims=sizes.shape, + vals=sizes.astype(int)) + arg_pad = onnx.helper.make_node('Constant', + inputs=[], + outputs=['arg_pad'], + value=pad_tensor) + + axes = np.array([1]) + axes_tensor = helper.make_tensor(name='pad_axes', + data_type=TensorProto.INT32, + dims=axes.shape, + vals=axes.astype(int)) + arg_axes = onnx.helper.make_node('Constant', + inputs=[], + outputs=['arg_axes'], + value=axes_tensor) + + node = onnx.helper.make_node('Pad', + mode='reflect', + inputs=['0', 'arg_pad', 'arg_axes'], + outputs=['1']) + + return ([arg_axes, arg_pad, node], [x], [y]) + + @onnx_test() def pad_reflect_multiaxis_test(): x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3]) diff --git a/test/onnx/onnx_test.cpp b/test/onnx/onnx_test.cpp index 415a612d0e0..fa006131050 100644 --- a/test/onnx/onnx_test.cpp +++ b/test/onnx/onnx_test.cpp @@ -4958,6 +4958,22 @@ TEST_CASE(pad_test) EXPECT(p == prog); } +TEST_CASE(pad_asym_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}}); + mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 1, 0, 3, 0, 2, 0, 4}}}), l0); + auto prog = optimize_onnx("pad_asym_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(pad_asym_invalid_pads_error_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("pad_asym_invalid_pads_error_test.onnx"); })); +} + TEST_CASE(pad_3arg_test) { migraphx::program p; @@ -4974,6 +4990,51 @@ TEST_CASE(pad_3arg_test) EXPECT(p == prog); } +TEST_CASE(pad_4arg_axes_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}}); + // axes=[1,3] + mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 3}}); + // constant_value=1 + mm->add_literal({migraphx::shape{migraphx::shape::float_type}, {1.0f}}); + // pads=[1,3,2,4] + mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {4}}, {1, 3, 2, 4}}); + auto r = mm->add_instruction( + migraphx::make_op("pad", {{"pads", {0, 1, 0, 3, 0, 2, 0, 4}}, {"value", 1.0f}}), l0); + mm->add_return({r}); + + auto prog = migraphx::parse_onnx("pad_4arg_axes_test.onnx"); + + EXPECT(p == prog); +} + +TEST_CASE(pad_4arg_invalid_axes_error_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("pad_4arg_invalid_axes_error_test.onnx"); })); +} + +TEST_CASE(pad_4arg_neg_axes_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}}); + // axes=[-3,-1] + mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {2}}, {-3, -1}}); + // constant_value=1 + mm->add_literal({migraphx::shape{migraphx::shape::float_type}, {1.0f}}); + // pads=[1,3,2,4] + mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {4}}, {1, 3, 2, 4}}); + auto r = mm->add_instruction( + migraphx::make_op("pad", {{"pads", {0, 1, 0, 3, 0, 2, 0, 4}}, {"value", 1.0f}}), l0); + mm->add_return({r}); + + auto prog = migraphx::parse_onnx("pad_4arg_neg_axes_test.onnx"); + + EXPECT(p == prog); +} + TEST_CASE(pad_attr_dyn_test) { migraphx::program p; @@ -5032,6 +5093,27 @@ TEST_CASE(pad_reflect_test) EXPECT(p == prog); } +TEST_CASE(pad_reflect_with_axes_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 2}}); + mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {1}}, {1}}); + mm->add_literal({migraphx::shape{migraphx::shape::int32_type, {2}}, {2, 1}}); + auto l1 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {0, 1}}, {"ends", {2, 2}}}), l0); + auto l2 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {0, 0}}, {"ends", {2, 1}}}), l0); + auto l3 = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {0, 0}}, {"ends", {2, 1}}}), l0); + auto r = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), l2, l1, l0, l3); + mm->add_return({r}); + + auto prog = migraphx::parse_onnx("pad_reflect_with_axes_test.onnx"); + + EXPECT(p == prog); +} + TEST_CASE(pad_reflect_multiaxis_test) { migraphx::program p; diff --git a/test/onnx/pad_4arg_axes_test.onnx b/test/onnx/pad_4arg_axes_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..a9681f4a8dc78f3dd5da8202f59b493f8f95dd28 GIT binary patch literal 305 zcmd(Oo=y1EJ}}0tVk`6FG(#fv6{lfWg^4@6I0@F&d)0@Nz5zJlH+16OUx-v z)e_=h5@6F}Vq|vW0O<#5UBJk~r61YwdeU+1|yqI@EoJ_gWNs$Zo#0)@0zyDbv}gy-ia4;5ehsN%bf!U{zM({jdT zSL0)2j;8#z1j+INh5U1}L~%&gI{p9G%>b%ok9Qf8f@6k21VlmgEtv6-*?61n1dHGa DY&KGU literal 0 HcmV?d00001 diff --git a/test/onnx/pad_4arg_neg_axes_test.onnx b/test/onnx/pad_4arg_neg_axes_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..69f8efd74b369512e3b30c4b732cb3374c6152e7 GIT binary patch literal 331 zcmdk{uFQ?k%qvUG$xMj{3KYku78T_e#h0WOmsmA0aw%~!8VE5ODlrEn zrf7+Bv4E5~FgP#*F*6V|Ens9PY*3Um$P_Uy5e`Nn0WKyEMj&PeViq7~O#(_9qDitr KC3&2f1cU*tDJ3-k literal 0 HcmV?d00001 diff --git a/test/onnx/pad_asym_test.onnx b/test/onnx/pad_asym_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..ffc3dc17f98e4ee0eaaba0b43fee0120e60fee28 GIT binary patch literal 136 zcmdk{uFQ=uNi8n1D&$h*Vl)t9G*n^^NKDa^uWxAZ=n?A{>lD0$fZSj6lo`#4JF}ngo literal 0 HcmV?d00001 diff --git a/test/onnx/pad_reflect_with_axes_test.onnx b/test/onnx/pad_reflect_with_axes_test.onnx new file mode 100644 index 00000000000..1556ed708f9 --- /dev/null +++ b/test/onnx/pad_reflect_with_axes_test.onnx @@ -0,0 +1,18 @@ + pad_reflect_with_axes_test:ä +3arg_axes"Constant* +value**Bpad_axes  +3arg_pad"Constant* +value**Bpad_size  +2 +0 +arg_pad +arg_axes1"Pad* +mode"reflect pad_reflect_with_axes_testZ +0 +  + +b +1 +  + +B \ No newline at end of file