diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index a1d5320d9f94..f05ec7f9b73d 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -3275,3 +3275,84 @@ def convert_gather_nd(node, **kwargs):
     ]
 
     return nodes
+
+
+def _batch_dot_swap_last_two_axes(data, prefix, name):
+    """Build onnx nodes that swap the last two axes of an N-D input.
+
+    The input is flattened to 3-D, transposed with perm=[0, 2, 1] and then
+    reshaped back so that the leading (batch) dimensions are preserved.
+
+    Parameters
+    ----------
+    data : str
+        Name of the onnx tensor to transpose.
+    prefix : str
+        Prefix for the intermediate tensor names; the transposed result is
+        emitted under exactly this name.
+    name : str
+        Name of the batch_dot node, used to look up the shared constants
+        name+'_0', name+'_-1', name+'_-2' and name+'_100'.
+
+    Returns
+    -------
+    list
+        onnx nodes, in execution order.
+    """
+    from onnx.helper import make_node
+    return [
+        make_node('Shape', [data], [prefix+'_shape']),
+        # shape[:-2] -> batch dims; the end index 100 is clamped by Slice, so
+        # slicing from -2 to 100 selects the last two (matrix) dims
+        make_node('Slice', [prefix+'_shape', name+'_0', name+'_-2'], [prefix+'_slice0']),
+        make_node('Slice', [prefix+'_shape', name+'_-2', name+'_100'], [prefix+'_slice1']),
+        # collapse all batch dims into one so a fixed 3-D perm can be used
+        make_node('Concat', [name+'_-1', prefix+'_slice1'], [prefix+'_concat1'], axis=0),
+        make_node('Reshape', [data, prefix+'_concat1'], [prefix+'_3d']),
+        make_node('Transpose', [prefix+'_3d'], [prefix+'_3d_transpose'], perm=[0, 2, 1]),
+        # restore the original batch dims in front of the swapped matrix dims
+        make_node('Shape', [prefix+'_3d_transpose'], [prefix+'_shape_3d']),
+        make_node('Slice', [prefix+'_shape_3d', name+'_-2', name+'_100'], [prefix+'_slice2']),
+        make_node('Concat', [prefix+'_slice0', prefix+'_slice2'], [prefix+'_concat2'], axis=0),
+        make_node('Reshape', [prefix+'_3d_transpose', prefix+'_concat2'], [prefix]),
+    ]
+
+
+@mx_op.register("batch_dot")
+def convert_batch_dot(node, **kwargs):
+    """Map MXNet's batch_dot operator attributes to onnx's operator.
+
+    The product is exported as an onnx MatMul; an operand whose transpose_a /
+    transpose_b attribute is set has its last two axes swapped first.
+    """
+    from onnx.helper import make_node
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+
+    lhs = input_nodes[0]
+    rhs = input_nodes[1]
+    # the flags are compared as strings; '0' and 'False' both mean "off"
+    transpose_a = str(attrs.get('transpose_a', '0')) not in ['0', 'False']
+    transpose_b = str(attrs.get('transpose_b', '0')) not in ['0', 'False']
+
+    nodes = []
+    if transpose_a or transpose_b:
+        # tensors shared by both transpose sub-graphs, registered only once
+        nodes += [
+            create_tensor([-2], name+'_-2', kwargs['initializer']),
+            create_tensor([-1], name+'_-1', kwargs['initializer']),
+            create_tensor([0], name+'_0', kwargs['initializer']),
+            create_tensor([100], name+'_100', kwargs['initializer']),
+        ]
+
+    if transpose_a:
+        nodes += _batch_dot_swap_last_two_axes(lhs, name+'_lhs', name)
+        lhs = name+'_lhs'
+
+    if transpose_b:
+        nodes += _batch_dot_swap_last_two_axes(rhs, name+'_rhs', name)
+        rhs = name+'_rhs'
+
+    nodes += [
+        make_node('MatMul', [lhs, rhs], [name]),
+    ]
+    return nodes
diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py
index 2c363a96ba04..afaec5c72164 100644
--- a/tests/python-pytest/onnx/test_operators.py
+++ b/tests/python-pytest/onnx/test_operators.py
@@ -539,3 +539,16 @@ def test_onnx_export_gather_nd(tmp_path, dtype):
     M2 = def_model('gather_nd')
     op_export_test('gather_nd2', M2, [x2, y2], tmp_path)
 
+
+@pytest.mark.parametrize('dtype', ['float32', 'float64'])
+@pytest.mark.parametrize('transpose_a', [True, False])
+@pytest.mark.parametrize('transpose_b', [True, False])
+def test_onnx_export_batch_dot(tmp_path, dtype, transpose_a, transpose_b):
+    x1 = mx.nd.random.normal(0, 10, (2, 3, 4, 5, 6), dtype=dtype)
+    y1 = mx.nd.random.normal(0, 10, (2, 3, 4, 6, 5), dtype=dtype)
+    M1 = def_model('batch_dot')
+    op_export_test('batch_dot1', M1, [x1, y1], tmp_path)
+    x2 = mx.nd.random.normal(0, 10, (2, 3, 4, 5, 5), dtype=dtype)
+    y2 = mx.nd.random.normal(0, 10, (2, 3, 4, 5, 5), dtype=dtype)
+    M2 = def_model('batch_dot', transpose_a=transpose_a, transpose_b=transpose_b)
+    op_export_test('batch_dot2', M2, [x2, y2], tmp_path)