From dacd3f7d66f491e54f9d085167c806758c00a8cd Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 13 Apr 2022 19:19:37 -0500 Subject: [PATCH 1/4] Implementation --- src/onnx/parse_reversesequence.cpp | 130 +++++++++++++++++ test/onnx/gen_onnx.py | 136 ++++++++++++++++++ test/onnx/onnx_test.cpp | 127 ++++++++++++++++ test/onnx/reversesequence_4D_test.onnx | Bin 0 -> 203 bytes .../reversesequence_batch_axis_err_test.onnx | Bin 0 -> 223 bytes test/onnx/reversesequence_batch_test.onnx | Bin 0 -> 248 bytes test/onnx/reversesequence_rank_err_test.onnx | 12 ++ .../reversesequence_same_axis_err_test.onnx | 16 +++ ...sequence_sequence_lens_shape_err_test.onnx | 12 ++ .../reversesequence_time_axis_err_test.onnx | Bin 0 -> 244 bytes test/onnx/reversesequence_time_test.onnx | Bin 0 -> 195 bytes test/onnx/verify_onnx.cpp | 63 ++++++++ test/py/onnx_backend_test.py | 1 + 13 files changed, 497 insertions(+) create mode 100644 src/onnx/parse_reversesequence.cpp create mode 100644 test/onnx/reversesequence_4D_test.onnx create mode 100644 test/onnx/reversesequence_batch_axis_err_test.onnx create mode 100644 test/onnx/reversesequence_batch_test.onnx create mode 100644 test/onnx/reversesequence_rank_err_test.onnx create mode 100644 test/onnx/reversesequence_same_axis_err_test.onnx create mode 100644 test/onnx/reversesequence_sequence_lens_shape_err_test.onnx create mode 100644 test/onnx/reversesequence_time_axis_err_test.onnx create mode 100644 test/onnx/reversesequence_time_test.onnx diff --git a/src/onnx/parse_reversesequence.cpp b/src/onnx/parse_reversesequence.cpp new file mode 100644 index 00000000000..cb93b9dd2d1 --- /dev/null +++ b/src/onnx/parse_reversesequence.cpp @@ -0,0 +1,130 @@ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +//! Parser for ReverseSequence ONNX operator. +/*! + Reverses the data along the time axis for the batches along the batch axis. + The sequence lengths can be given to reverse up to the given length for each batch, keeping the + rest of the sequence in the original order. Variable sequence_lens is not supported in this + version of MIGraphX. You can pass the sequence_lens either as a constant node or an attribute. The + batch axis and time axis must be [0, 1] and not the same. +*/ +struct parse_reversesequence : op_parser +{ + std::vector operators() const { return {{"ReverseSequence"}}; } + + instruction_ref parse(const op_desc& /*opd*/, + const onnx_parser& parser, + const onnx_parser::node_info& info, + std::vector args) const + { + int batch_axis = 1; + if(contains(info.attributes, "batch_axis")) + { + batch_axis = info.attributes.at("batch_axis").i(); + } + if(batch_axis != 0 and batch_axis != 1) + { + MIGRAPHX_THROW("PARSE_REVERSESEQUENCE: batch axis not 0 or 1"); + } + + int time_axis = 0; + if(contains(info.attributes, "time_axis")) + { + time_axis = info.attributes.at("time_axis").i(); + } + if(time_axis != 0 and time_axis != 1) + { + MIGRAPHX_THROW("REVERSESEQUENCE: time axis not 0 or 1"); + } + + if(time_axis == batch_axis) + { + MIGRAPHX_THROW("REVERSESEQUENCE: time axis and batch axis are the same"); + } + + auto input = args[0]; + auto input_lens = input->get_shape().lens(); + if(input_lens.size() < 2) + { + MIGRAPHX_THROW("REVERSESEQUENCE: input tensor must have rank >= 2"); + } + + std::vector sequence_lens; + if(args.size() == 2) + { + migraphx::argument seq_lens_arg = args.back()->eval(); + check_arg_empty(seq_lens_arg, + "PARSE_REVERSESEQUENCE: cannot handle variable sequence_lens"); + seq_lens_arg.visit([&](auto s) { sequence_lens.assign(s.begin(), s.end()); }); + } + else if(contains(info.attributes, "sequence_lens")) + { + literal s = parser.parse_value(info.attributes.at("sequence_lens")); + s.visit([&](auto v) { sequence_lens.assign(v.begin(), v.end()); }); + } + auto batch_size = input_lens[batch_axis]; + auto time_size = input_lens[time_axis]; + + // this condition may still work even if the inputted shape was incorrect + if(sequence_lens.size() != batch_size) + { + MIGRAPHX_THROW("REVERSESEQUENCE: sequence_lens has incorrect shape"); + } + + instruction_ref ret; + + for(int b = 0; b < batch_size; ++b) + { + instruction_ref s0; + if(sequence_lens[b] > 1) + { + s0 = info.add_instruction(make_op("slice", + {{"axes", {batch_axis, time_axis}}, + {"starts", {b, 0}}, + {"ends", {b + 1, sequence_lens[b]}}}), + input); + s0 = info.add_instruction(make_op("reverse", {{"axes", {time_axis}}}), s0); + + // if reversed less than whole batch, concat rest of batch + if(sequence_lens[b] < time_size) + { + auto s1 = info.add_instruction(make_op("slice", + {{"axes", {batch_axis, time_axis}}, + {"starts", {b, sequence_lens[b]}}, + {"ends", {b + 1, time_size}}}), + input); + s0 = info.add_instruction(make_op("concat", {{"axis", time_axis}}), s0, s1); + } + } + else + { // cases where nothing changes + s0 = info.add_instruction(make_op("slice", + {{"axes", {batch_axis, time_axis}}, + {"starts", {b, 0}}, + {"ends", {b + 1, time_size}}}), + input); + } + if(b == 0) + { + ret = s0; + } + else + { + ret = info.add_instruction(make_op("concat", {{"axis", batch_axis}}), ret, s0); + } + } + return ret; + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index ad4980c949e..5aa98dacfa6 100755 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -4345,6 +4345,142 @@ def resize_upsample_pc_test(): return ([node], [X], [Y], [scale_tensor]) +@onnx_test +def reversesequence_4D_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 2, 2, 2]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2, 2, 2]) + + node = onnx.helper.make_node( + 'ReverseSequence', + inputs=['x'], + outputs=['y'], + time_axis=0, + batch_axis=1, + sequence_lens=[2, 1], + ) + return ([node], [x], [y]) + + +@onnx_test +def reversesequence_batch_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [4, 4]) + seq_lens = np.array([1, 2, 3, 4]) + seq_lens_tensor = helper.make_tensor( + name="sequence_lens", + data_type=TensorProto.INT64, + dims=seq_lens.shape, + vals=seq_lens.astype(np.int64), + ) + arg_seq_lens = helper.make_node( + "Constant", + inputs=[], + outputs=['arg_seq_lens'], + value=seq_lens_tensor, + ) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4, 4]) + + node = onnx.helper.make_node( + 'ReverseSequence', + inputs=['x', 'arg_seq_lens'], + outputs=['y'], + time_axis=1, + batch_axis=0, + ) + return ([arg_seq_lens, node], [x], [y]) + + +@onnx_test +def reversesequence_batch_axis_err_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [4, 4, 2]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4, 4, 2]) + + node = onnx.helper.make_node( + 'ReverseSequence', + inputs=['x'], + outputs=['y'], + time_axis=0, + batch_axis=2, + sequence_lens=[4, 3, 2, 1], + ) + return ([node], [x], [y]) + + +@onnx_test +def reversesequence_rank_err_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [4]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4]) + + node = onnx.helper.make_node( + 'ReverseSequence', + inputs=['x'], + outputs=['y'], + sequence_lens=[4, 3, 2, 1], + ) + return ([node], [x], [y]) + + +@onnx_test +def reversesequence_sequence_lens_shape_err_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [4, 4]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4, 4]) + + node = onnx.helper.make_node( + 'ReverseSequence', + inputs=['x'], + outputs=['y'], + sequence_lens=[4, 3, 2], + ) + return ([node], [x], [y]) + + +@onnx_test +def reversesequence_same_axis_err_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [4, 4]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4, 4]) + + node = onnx.helper.make_node( + 'ReverseSequence', + inputs=['x', 'sequence_lens'], + outputs=['y'], + time_axis=1, + batch_axis=1, + sequence_lens=[4, 3, 2, 1], + ) + return ([node], [x], [y]) + + +@onnx_test +def reversesequence_time_axis_err_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [4, 4, 2, 3]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4, 4, 2, 3]) + + node = onnx.helper.make_node( + 'ReverseSequence', + inputs=['x', 'sequence_lens'], + outputs=['y'], + time_axis=3, + batch_axis=0, + sequence_lens=[4, 3, 2, 1], + ) + return ([node], [x], [y]) + + +@onnx_test +def reversesequence_time_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [4, 4]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [4, 4]) + + node = onnx.helper.make_node( + 'ReverseSequence', + inputs=['x'], + outputs=['y'], + time_axis=0, + batch_axis=1, + sequence_lens=[4, 3, 2, 1], + ) + return ([node], [x], [y]) + + @onnx_test def roialign_default_test(): x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 4, 7, 8]) diff --git a/test/onnx/onnx_test.cpp b/test/onnx/onnx_test.cpp index bc6a3c04e70..5722ab21ad3 100644 --- a/test/onnx/onnx_test.cpp +++ b/test/onnx/onnx_test.cpp @@ -4173,6 +4173,133 @@ TEST_CASE(resize_upsample_pf_test) EXPECT(p == prog); } +TEST_CASE(reversesequence_batch_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + int batch_axis = 0; + int time_axis = 1; + + migraphx::shape sx{migraphx::shape::float_type, {4, 4}}; + auto input = mm->add_parameter("x", sx); + + std::vector sequence_lens = {1, 2, 3, 4}; + mm->add_literal({{migraphx::shape::int64_type, {4}}, sequence_lens}); + + int batch_size = sx.lens()[batch_axis]; + int time_size = sx.lens()[time_axis]; + + auto ret = mm->add_instruction( + migraphx::make_op( + "slice", + {{"axes", {batch_axis, time_axis}}, {"starts", {0, 0}}, {"ends", {1, time_size}}}), + input); + for(int b = 1; b < batch_size; ++b) + { + auto s0 = mm->add_instruction(migraphx::make_op("slice", + {{"axes", {batch_axis, time_axis}}, + {"starts", {b, 0}}, + {"ends", {b + 1, sequence_lens[b]}}}), + input); + s0 = mm->add_instruction(migraphx::make_op("reverse", {{"axes", {time_axis}}}), s0); + if(sequence_lens[b] < time_size) + { + auto s1 = mm->add_instruction(migraphx::make_op("slice", + {{"axes", {batch_axis, time_axis}}, + {"starts", {b, sequence_lens[b]}}, + {"ends", {b + 1, time_size}}}), + input); + s0 = mm->add_instruction(migraphx::make_op("concat", {{"axis", time_axis}}), s0, s1); + } + ret = mm->add_instruction(migraphx::make_op("concat", {{"axis", batch_axis}}), ret, s0); + } + mm->add_return({ret}); + + auto prog = migraphx::parse_onnx("reversesequence_batch_test.onnx"); + EXPECT(p == prog); +} + +TEST_CASE(reversesequence_batch_axis_err_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("reversesequence_batch_axis_err_test.onnx"); })); +} + +TEST_CASE(reversesequence_rank_err_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("reversesequence_rank_err_test.onnx"); })); +} + +TEST_CASE(reversesequence_sequence_lens_shape_err_test) +{ + EXPECT(test::throws( + [&] { migraphx::parse_onnx("reversesequence_sequence_lens_shape_err_test.onnx"); })); +} + +TEST_CASE(reversesequence_same_axis_err_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("reversesequence_same_axis_err_test.onnx"); })); +} + +TEST_CASE(reversesequence_time_axis_err_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("reversesequence_time_axis_err_test.onnx"); })); +} + +TEST_CASE(reversesequence_time_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + int batch_axis = 1; + int time_axis = 0; + + migraphx::shape sx{migraphx::shape::float_type, {4, 4}}; + auto input = mm->add_parameter("x", sx); + + int batch_size = sx.lens()[batch_axis]; + int time_size = sx.lens()[time_axis]; + std::vector sequence_lens = {4, 3, 2, 1}; + + migraphx::instruction_ref ret; + for(int b = 0; b < batch_size - 1; ++b) + { + auto s0 = mm->add_instruction(migraphx::make_op("slice", + {{"axes", {batch_axis, time_axis}}, + {"starts", {b, 0}}, + {"ends", {b + 1, sequence_lens[b]}}}), + input); + s0 = mm->add_instruction(migraphx::make_op("reverse", {{"axes", {time_axis}}}), s0); + if(sequence_lens[b] < time_size) + { + auto s1 = mm->add_instruction(migraphx::make_op("slice", + {{"axes", {batch_axis, time_axis}}, + {"starts", {b, sequence_lens[b]}}, + {"ends", {b + 1, time_size}}}), + input); + s0 = mm->add_instruction(migraphx::make_op("concat", {{"axis", time_axis}}), s0, s1); + } + if(b == 0) + { + ret = s0; + } + else + { + ret = mm->add_instruction(migraphx::make_op("concat", {{"axis", batch_axis}}), ret, s0); + } + } + auto s0 = mm->add_instruction(migraphx::make_op("slice", + {{"axes", {batch_axis, time_axis}}, + {"starts", {batch_size - 1, 0}}, + {"ends", {batch_size, time_size}}}), + input); + ret = mm->add_instruction(migraphx::make_op("concat", {{"axis", batch_axis}}), ret, s0); + mm->add_return({ret}); + + auto prog = migraphx::parse_onnx("reversesequence_time_test.onnx"); + EXPECT(p == prog); +} + TEST_CASE(roialign_default_test) { migraphx::shape sx{migraphx::shape::float_type, {10, 4, 7, 8}}; diff --git a/test/onnx/reversesequence_4D_test.onnx b/test/onnx/reversesequence_4D_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..7fd4e02f35ed2f2403344366337f95885a142b79 GIT binary patch literal 203 zcmd;J7ZNW@ElVvbPAyI?EKSWzPK`Hli7!blF0oq4$Q8oHSRuq%sl*=yQ4T=dm!uPu051S~8!%G< literal 0 HcmV?d00001 diff --git a/test/onnx/reversesequence_batch_axis_err_test.onnx b/test/onnx/reversesequence_batch_axis_err_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..4c80e8e3679842f71b7c7249bd0aef5de3d38c80 GIT binary patch literal 223 zcmd;J7g8=tElVvbPAyI?EKSWzPK{4WEJ@CYPprr+j!!KriZ4kmF0tCk$Q8lGSRuq% zsl*=yQ6CIduO-OEh0r3ww1AOGONxsZZd^`kUa$3luXY0VOK2NjNbH@B#pUXgf;) literal 0 HcmV?d00001 diff --git a/test/onnx/reversesequence_batch_test.onnx b/test/onnx/reversesequence_batch_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..eaf5f42e5e5382f932bd5a3ea09653d80c63c5a8 GIT binary patch literal 248 zcmd;J7m_MUElVvbPAyI?EKSWzPK{4WEJ@CYFG(#fvAWF2WiP~&Sd<3&ji}#!4mr zAc(EOP+PSGxws$}Cst$@OE4^8WYQAg;w;I`O^pYOFakyJxiv}{Xpayd7Y_%c5C<0% R2MY)%0R=131)Z1#cmewWLvH{8 literal 0 HcmV?d00001 diff --git a/test/onnx/reversesequence_rank_err_test.onnx b/test/onnx/reversesequence_rank_err_test.onnx new file mode 100644 index 00000000000..0aec563af40 --- /dev/null +++ b/test/onnx/reversesequence_rank_err_test.onnx @@ -0,0 +1,12 @@ +reversesequence_rank_err_test:v +3 +xy"ReverseSequence* + sequence_lens@@@@ reversesequence_rank_err_testZ +x + + +b +y + + +B \ No newline at end of file diff --git a/test/onnx/reversesequence_same_axis_err_test.onnx b/test/onnx/reversesequence_same_axis_err_test.onnx new file mode 100644 index 00000000000..cea01ec7d60 --- /dev/null +++ b/test/onnx/reversesequence_same_axis_err_test.onnx @@ -0,0 +1,16 @@ +"reversesequence_same_axis_err_test:· +g +x + sequence_lensy"ReverseSequence* + +batch_axis * + sequence_lens@@@@ * + time_axis "reversesequence_same_axis_err_testZ +x +  + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/reversesequence_sequence_lens_shape_err_test.onnx b/test/onnx/reversesequence_sequence_lens_shape_err_test.onnx new file mode 100644 index 00000000000..3f08cf6a406 --- /dev/null +++ b/test/onnx/reversesequence_sequence_lens_shape_err_test.onnx @@ -0,0 +1,12 @@ +,reversesequence_sequence_lens_shape_err_test:‹ +1 +xy"ReverseSequence* + sequence_lens@@@ ,reversesequence_sequence_lens_shape_err_testZ +x +  + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/reversesequence_time_axis_err_test.onnx b/test/onnx/reversesequence_time_axis_err_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..81fd08c44e7fba03fbfae93ba026b014f3e5e0ed GIT binary patch literal 244 zcmd;J7g8!pElVvbPAyI?EKSWzPK__g%uS6?tjH{mPc15nFG(#fu{zGkmCnUj!Nm(# zl#`lQEW}u;#2*CF9Sqg2CCJ592E7#A=yX-RSM!VSqu%`0|babR{}asbM)YYA|1g3Sf1W&oEIxz|G0sx*vGP(c& literal 0 HcmV?d00001 diff --git a/test/onnx/verify_onnx.cpp b/test/onnx/verify_onnx.cpp index 305d4ac5bf8..a93f693d263 100644 --- a/test/onnx/verify_onnx.cpp +++ b/test/onnx/verify_onnx.cpp @@ -698,6 +698,69 @@ TEST_CASE(resize_upsample_pf_test) EXPECT(migraphx::verify_range(result_vector, gold)); } +TEST_CASE(reversesequence_4D_verify_test) +{ + migraphx::program p = migraphx::parse_onnx("reversesequence_4D_test.onnx"); + p.compile(migraphx::ref::target{}); + + migraphx::shape xs{migraphx::shape::float_type, {2, 2, 2, 2}}; + std::vector x_data = { + 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0}; + migraphx::parameter_map param_map; + param_map["x"] = migraphx::argument(xs, x_data.data()); + + auto result = p.eval(param_map).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = { + 8.0, 9.0, 10.0, 11.0, 4.0, 5.0, 6.0, 7.0, 0.0, 1.0, 2.0, 3.0, 12.0, 13.0, 14.0, 15.0}; + + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(reversesequence_batch_verify_test) +{ + migraphx::program p = migraphx::parse_onnx("reversesequence_batch_test.onnx"); + p.compile(migraphx::ref::target{}); + + migraphx::shape xs{migraphx::shape::float_type, {4, 4}}; + std::vector x_data = { + 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0}; + migraphx::parameter_map param_map; + param_map["x"] = migraphx::argument(xs, x_data.data()); + + auto result = p.eval(param_map).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = { + 0.0, 1.0, 2.0, 3.0, 5.0, 4.0, 6.0, 7.0, 10.0, 9.0, 8.0, 11.0, 15.0, 14.0, 13.0, 12.0}; + + EXPECT(migraphx::verify_range(result_vector, gold)); +} + +TEST_CASE(reversesequence_time_verify_test) +{ + migraphx::program p = migraphx::parse_onnx("reversesequence_time_test.onnx"); + p.compile(migraphx::ref::target{}); + + migraphx::shape xs{migraphx::shape::float_type, {4, 4}}; + std::vector x_data = { + 0.0, 4.0, 8.0, 12.0, 1.0, 5.0, 9.0, 13.0, 2.0, 6.0, 10.0, 14.0, 3.0, 7.0, 11.0, 15.0}; + migraphx::parameter_map param_map; + param_map["x"] = migraphx::argument(xs, x_data.data()); + + auto result = p.eval(param_map).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = { + 3.0, 6.0, 9.0, 12.0, 2.0, 5.0, 8.0, 13.0, 1.0, 4.0, 10.0, 14.0, 0.0, 7.0, 11.0, 15.0}; + + EXPECT(migraphx::verify_range(result_vector, gold)); +} + TEST_CASE(selu_test) { migraphx::program p = migraphx::parse_onnx("selu_test.onnx"); diff --git a/test/py/onnx_backend_test.py b/test/py/onnx_backend_test.py index 09e01cab6fa..125cb1d6277 100755 --- a/test/py/onnx_backend_test.py +++ b/test/py/onnx_backend_test.py @@ -178,6 +178,7 @@ def create_backend_test(testname=None, target_device=None): backend_test.include(r'.*test_reduce.*') backend_test.include(r'.*test_ReLU*') backend_test.include(r'.*test_relu.*') + backend_test.include(r'.*test_reversesequence.*') backend_test.include(r'.*test_RoiAlign*') backend_test.include(r'.*test_roialign.*') backend_test.include(r'.*test_scatter.*') From 3c1df5e9b2fc1bb7f518147d0d3cf6dd88345cbb Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 13 Apr 2022 19:46:19 -0500 Subject: [PATCH 2/4] Disable ONNX test Current implementation cannot handle variable sequence lengths --- test/py/onnx_backend_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/py/onnx_backend_test.py b/test/py/onnx_backend_test.py index 125cb1d6277..1e4cd76e85b 100755 --- a/test/py/onnx_backend_test.py +++ b/test/py/onnx_backend_test.py @@ -178,7 +178,7 @@ def create_backend_test(testname=None, target_device=None): backend_test.include(r'.*test_reduce.*') backend_test.include(r'.*test_ReLU*') backend_test.include(r'.*test_relu.*') - backend_test.include(r'.*test_reversesequence.*') + #backend_test.include(r'.*test_reversesequence.*') backend_test.include(r'.*test_RoiAlign*') backend_test.include(r'.*test_roialign.*') backend_test.include(r'.*test_scatter.*') From 3b75206b5d72f454c8df35a15865cee558a5de1e Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 20 Apr 2022 14:22:52 -0500 Subject: [PATCH 3/4] Review updates --- src/onnx/parse_reversesequence.cpp | 33 ++++++++----------- test/onnx/onnx_test.cpp | 53 +++++++++++++----------------- 2 files changed, 37 insertions(+), 49 deletions(-) diff --git a/src/onnx/parse_reversesequence.cpp b/src/onnx/parse_reversesequence.cpp index cb93b9dd2d1..f1e9fced7c2 100644 --- a/src/onnx/parse_reversesequence.cpp +++ b/src/onnx/parse_reversesequence.cpp @@ -32,7 +32,7 @@ struct parse_reversesequence : op_parser } if(batch_axis != 0 and batch_axis != 1) { - MIGRAPHX_THROW("PARSE_REVERSESEQUENCE: batch axis not 0 or 1"); + MIGRAPHX_THROW("REVERSESEQUENCE: batch axis not 0 or 1"); } int time_axis = 0; @@ -61,8 +61,7 @@ struct parse_reversesequence : op_parser if(args.size() == 2) { migraphx::argument seq_lens_arg = args.back()->eval(); - check_arg_empty(seq_lens_arg, - "PARSE_REVERSESEQUENCE: cannot handle variable sequence_lens"); + check_arg_empty(seq_lens_arg, "REVERSESEQUENCE: cannot handle variable sequence_lens"); seq_lens_arg.visit([&](auto s) { sequence_lens.assign(s.begin(), s.end()); }); } else if(contains(info.attributes, "sequence_lens")) @@ -73,7 +72,7 @@ struct parse_reversesequence : op_parser auto batch_size = input_lens[batch_axis]; auto time_size = input_lens[time_axis]; - // this condition may still work even if the inputted shape was incorrect + // this condition may still work if sequence_len's shape was incorrect if(sequence_lens.size() != batch_size) { MIGRAPHX_THROW("REVERSESEQUENCE: sequence_lens has incorrect shape"); @@ -81,36 +80,32 @@ struct parse_reversesequence : op_parser instruction_ref ret; + auto add_slice = [&info, &input, batch_axis, time_axis](int b, int t_start, int t_end) { + return info.add_instruction(make_op("slice", + {{"axes", {batch_axis, time_axis}}, + {"starts", {b, t_start}}, + {"ends", {b + 1, t_end}}}), + input); + }; + for(int b = 0; b < batch_size; ++b) { instruction_ref s0; if(sequence_lens[b] > 1) { - s0 = info.add_instruction(make_op("slice", - {{"axes", {batch_axis, time_axis}}, - {"starts", {b, 0}}, - {"ends", {b + 1, sequence_lens[b]}}}), - input); + s0 = add_slice(b, 0, sequence_lens[b]); s0 = info.add_instruction(make_op("reverse", {{"axes", {time_axis}}}), s0); // if reversed less than whole batch, concat rest of batch if(sequence_lens[b] < time_size) { - auto s1 = info.add_instruction(make_op("slice", - {{"axes", {batch_axis, time_axis}}, - {"starts", {b, sequence_lens[b]}}, - {"ends", {b + 1, time_size}}}), - input); + auto s1 = add_slice(b, sequence_lens[b], time_size); s0 = info.add_instruction(make_op("concat", {{"axis", time_axis}}), s0, s1); } } else { // cases where nothing changes - s0 = info.add_instruction(make_op("slice", - {{"axes", {batch_axis, time_axis}}, - {"starts", {b, 0}}, - {"ends", {b + 1, time_size}}}), - input); + s0 = add_slice(b, 0, time_size); } if(b == 0) { diff --git a/test/onnx/onnx_test.cpp b/test/onnx/onnx_test.cpp index 0f5472accf3..59c3acf457a 100644 --- a/test/onnx/onnx_test.cpp +++ b/test/onnx/onnx_test.cpp @@ -4239,26 +4239,22 @@ TEST_CASE(reversesequence_batch_test) int batch_size = sx.lens()[batch_axis]; int time_size = sx.lens()[time_axis]; - auto ret = mm->add_instruction( - migraphx::make_op( - "slice", - {{"axes", {batch_axis, time_axis}}, {"starts", {0, 0}}, {"ends", {1, time_size}}}), - input); + auto add_slice = + [&mm, &input, batch_axis, time_axis](int b_start, int b_end, int t_start, int t_end) { + return mm->add_instruction(migraphx::make_op("slice", + {{"axes", {batch_axis, time_axis}}, + {"starts", {b_start, t_start}}, + {"ends", {b_end, t_end}}}), + input); + }; + auto ret = add_slice(0, 1, 0, time_size); for(int b = 1; b < batch_size; ++b) { - auto s0 = mm->add_instruction(migraphx::make_op("slice", - {{"axes", {batch_axis, time_axis}}, - {"starts", {b, 0}}, - {"ends", {b + 1, sequence_lens[b]}}}), - input); + auto s0 = add_slice(b, b + 1, 0, sequence_lens[b]); s0 = mm->add_instruction(migraphx::make_op("reverse", {{"axes", {time_axis}}}), s0); if(sequence_lens[b] < time_size) { - auto s1 = mm->add_instruction(migraphx::make_op("slice", - {{"axes", {batch_axis, time_axis}}, - {"starts", {b, sequence_lens[b]}}, - {"ends", {b + 1, time_size}}}), - input); + auto s1 = add_slice(b, b + 1, sequence_lens[b], time_size); s0 = mm->add_instruction(migraphx::make_op("concat", {{"axis", time_axis}}), s0, s1); } ret = mm->add_instruction(migraphx::make_op("concat", {{"axis", batch_axis}}), ret, s0); @@ -4310,22 +4306,23 @@ TEST_CASE(reversesequence_time_test) int time_size = sx.lens()[time_axis]; std::vector sequence_lens = {4, 3, 2, 1}; + auto add_slice = + [&mm, &input, batch_axis, time_axis](int b_start, int b_end, int t_start, int t_end) { + return mm->add_instruction(migraphx::make_op("slice", + {{"axes", {batch_axis, time_axis}}, + {"starts", {b_start, t_start}}, + {"ends", {b_end, t_end}}}), + input); + }; + migraphx::instruction_ref ret; for(int b = 0; b < batch_size - 1; ++b) { - auto s0 = mm->add_instruction(migraphx::make_op("slice", - {{"axes", {batch_axis, time_axis}}, - {"starts", {b, 0}}, - {"ends", {b + 1, sequence_lens[b]}}}), - input); + auto s0 = add_slice(b, b + 1, 0, sequence_lens[b]); s0 = mm->add_instruction(migraphx::make_op("reverse", {{"axes", {time_axis}}}), s0); if(sequence_lens[b] < time_size) { - auto s1 = mm->add_instruction(migraphx::make_op("slice", - {{"axes", {batch_axis, time_axis}}, - {"starts", {b, sequence_lens[b]}}, - {"ends", {b + 1, time_size}}}), - input); + auto s1 = add_slice(b, b + 1, sequence_lens[b], time_size); s0 = mm->add_instruction(migraphx::make_op("concat", {{"axis", time_axis}}), s0, s1); } if(b == 0) @@ -4337,11 +4334,7 @@ TEST_CASE(reversesequence_time_test) ret = mm->add_instruction(migraphx::make_op("concat", {{"axis", batch_axis}}), ret, s0); } } - auto s0 = mm->add_instruction(migraphx::make_op("slice", - {{"axes", {batch_axis, time_axis}}, - {"starts", {batch_size - 1, 0}}, - {"ends", {batch_size, time_size}}}), - input); + auto s0 = add_slice(batch_size - 1, batch_size, 0, time_size); ret = mm->add_instruction(migraphx::make_op("concat", {{"axis", batch_axis}}), ret, s0); mm->add_return({ret}); From d54848831ee95e0f4587141fc9e171cf85ec0432 Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 22 Apr 2022 12:00:19 -0500 Subject: [PATCH 4/4] Fix gen_onnx typo --- test/onnx/gen_onnx.py | 4 ++-- .../reversesequence_same_axis_err_test.onnx | 7 +++---- .../reversesequence_time_axis_err_test.onnx | Bin 244 -> 229 bytes 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index 908b269cf99..4521856a57c 100755 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -4480,7 +4480,7 @@ def reversesequence_same_axis_err_test(): node = onnx.helper.make_node( 'ReverseSequence', - inputs=['x', 'sequence_lens'], + inputs=['x'], outputs=['y'], time_axis=1, batch_axis=1, @@ -4496,7 +4496,7 @@ def reversesequence_time_axis_err_test(): node = onnx.helper.make_node( 'ReverseSequence', - inputs=['x', 'sequence_lens'], + inputs=['x'], outputs=['y'], time_axis=3, batch_axis=0, diff --git a/test/onnx/reversesequence_same_axis_err_test.onnx b/test/onnx/reversesequence_same_axis_err_test.onnx index cea01ec7d60..10cd74336b0 100644 --- a/test/onnx/reversesequence_same_axis_err_test.onnx +++ b/test/onnx/reversesequence_same_axis_err_test.onnx @@ -1,7 +1,6 @@ -"reversesequence_same_axis_err_test:· -g -x - sequence_lensy"ReverseSequence* +"reversesequence_same_axis_err_test:¨ +X +xy"ReverseSequence* batch_axis * sequence_lens@@@@ * diff --git a/test/onnx/reversesequence_time_axis_err_test.onnx b/test/onnx/reversesequence_time_axis_err_test.onnx index 81fd08c44e7fba03fbfae93ba026b014f3e5e0ed..5768922d8283580043cc8b69b98a01757220f3a2 100644 GIT binary patch delta 17 Ycmeyu_>^&iI{OYrt_Uv1iix({0X8)Sy#N3J delta 32 ncmaFL_=RzTy4Z0>u5>QO3NGH_)WXu#yyVpQoYcJHiF(@sv2Y7q