diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 8ba578f7b300..09dcf88fe92c 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -1909,7 +1909,7 @@ def convert_roipooling(node, **kwargs): return [node] -@mx_op.register("Tile") +@mx_op.register("tile") def convert_tile(node, **kwargs): """Map MXNet's Tile operator attributes to onnx's Tile operator and return the created node. diff --git a/tests/python-pytest/onnx/test_node.py b/tests/python-pytest/onnx/test_node.py index 25fe9c9f9a51..ce22b8320675 100644 --- a/tests/python-pytest/onnx/test_node.py +++ b/tests/python-pytest/onnx/test_node.py @@ -31,7 +31,7 @@ from collections import namedtuple import numpy as np import numpy.testing as npt -from onnx import numpy_helper, helper, load_model +from onnx import checker, numpy_helper, helper, load_model from onnx import TensorProto from mxnet.test_utils import download from mxnet.contrib import onnx as onnx_mxnet @@ -206,6 +206,18 @@ def test_imports(self): mxnet_out = bkd_rep.run(inputs) npt.assert_almost_equal(np_out, mxnet_out, decimal=4) + def test_exports(self): + input_shape = (2,1,3,1) + for test in export_test_cases: + test_name, onnx_name, mx_op, attrs = test + input_sym = mx.sym.var('data') + outsym = mx_op(input_sym, **attrs) + converted_model = onnx_mxnet.export_model(outsym, {}, [input_shape], np.float32, + onnx_file_path=outsym.name + ".onnx") + model = load_model(converted_model) + checker.check_model(model) + + # test_case = ("test_case_name", mxnet op, "ONNX_op_name", [input_list], attribute map, MXNet_specific=True/False, # fix_attributes = {'modify': {mxnet_attr_name: onnx_attr_name}, # 'remove': [attr_name], @@ -274,5 +286,10 @@ def test_imports(self): ("test_lpnormalization_ord2", "LpNormalization", [get_rnd([5, 3, 3, 2])], np.linalg.norm, {'ord':2, 'axis':1}) ] +export_test_cases = [ + ("test_expand", "Expand", mx.sym.broadcast_to, {'shape': (2,1,3,1)}), + ("test_tile", "Tile", mx.sym.tile, {'reps': (2,3)}) +] + if __name__ == '__main__': unittest.main()