diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index ffeb0dd731713..1e02910c01698 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2114,6 +2114,16 @@ def _impl_v9(cls, inputs, attr, params): return _op.transpose(output, axes=(1, 0)) + +class ReverseSequence(OnnxOpConverter): + """Operator converter for ReverseSequence""" + + @classmethod + def _impl_v10(cls, inputs, attr, params): + + return _op.reverse_sequence(inputs[0], inputs[1], attr["time_axis"], attr["batch_axis"]) + + class TopK(OnnxOpConverter): """Operator converter for TopK""" @@ -2801,6 +2811,7 @@ def _get_convert_map(opset): "QuantizeLinear": QuantizeLinear.get_converter(opset), "DequantizeLinear": DequantizeLinear.get_converter(opset), "DynamicQuantizeLinear": DynamicQuantizeLinear.get_converter(opset), + "ReverseSequence": ReverseSequence.get_converter(opset), } diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 6d22b5afd0df4..5ddd353c6d550 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4328,6 +4328,51 @@ def verify_embedding_bag(num_embedding, embedding_dim, data_shape, num_bags=None verify_embedding_bag(32, 2, [3, 3]) +def verify_reverse_sequence(x, sequence_lens, batch_axis, time_axis): + node = onnx.helper.make_node( + "ReverseSequence", + inputs=["x", "sequence_lens"], + outputs=["y"], + time_axis=time_axis, + batch_axis=batch_axis, + ) + + graph = helper.make_graph( + [node], + "reverse_sequence_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x.shape)), + helper.make_tensor_value_info( + "sequence_lens", TensorProto.INT64, list(sequence_lens.shape) + ), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(x.shape))], + ) + + model = helper.make_model(graph, producer_name="reverse_sequence_test") + verify_with_ort_with_inputs(model, [x, sequence_lens], list(x.shape)) + + +@tvm.testing.uses_gpu +def test_reverse_sequence(): + x = np.array( + [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]], + dtype=np.float32, + ) + sequence_lens = np.array([1, 2, 3, 4], dtype=np.int64) + y = np.array( + [[0, 5, 10, 15], [4, 1, 6, 11], [8, 9, 2, 7], [12, 13, 14, 3]], + dtype=np.float32, + ) + verify_reverse_sequence(x, sequence_lens, 0, 1) + + y = np.array( + [[0, 1, 2, 3], [5, 4, 6, 7], [10, 9, 8, 11], [15, 14, 13, 12]], + dtype=np.float32, + ) + verify_reverse_sequence(x, sequence_lens, 1, 0) + + if __name__ == "__main__": test_flatten() test_reshape() @@ -4407,4 +4452,8 @@ def verify_embedding_bag(num_embedding, embedding_dim, data_shape, num_bags=None test_softplus() test_cumsum() test_wrong_input() +<<<<<<< HEAD test_aten() +======= + test_reverse_sequence() +>>>>>>> 726b946b9... add onnx reverse sequence op