diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py index 2f38faa05f2b..a3c3c79809ed 100644 --- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py +++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py @@ -4440,6 +4440,7 @@ def convert_sequence_reverse(node, **kwargs): return nodes + @mx_op.register("RNN") def convert_RNN(node, **kwargs): """Map MXNet's RNN operator attributes to onnx's operators @@ -4810,6 +4811,7 @@ def convert_RNN(node, **kwargs): raise NotImplementedError(f"Currently RNN onnx export does not support {mode} mode") return nodes + @mx_op.register('_rnn_param_concat') def convert_rnn_param_concat(node, **kwargs): """Map MXNet's _rnn_param_concat operator @@ -4852,3 +4854,36 @@ def convert_contrib_div_sqrt_dim(node, **kwargs): ] return nodes + + +@mx_op.register('_split_v2') +def convert_contrib_split_v2(node, **kwargs): + """Map MXNet's _split_v2 operator + """ + from onnx.helper import make_node + name, input_nodes, attrs = get_inputs(node, kwargs) + axis = int(attrs.get('axis', 0)) + squeeze_axis = attrs.get('squeeze_axis', 'False') + sections = int(attrs.get('sections', 0)) + indices = convert_string_to_list(attrs.get('indices', '[]')) + if sections <= 0 and len(indices) == 0: + raise NotImplementedError('section or indices must be set') + if sections > 0: + output_nodes = [name+str(i) for i in range(sections)] + if squeeze_axis == 'False': + nodes = [ + make_node('Split', input_nodes, output_nodes, axis=axis), + ] + else: + output_nodes_ = [name+str(i)+'_' for i in range(sections)] + nodes = [ + make_node('Split', input_nodes, output_nodes_, axis=axis), + ] + for i in range(sections): + nodes += [ + make_node("Squeeze", [output_nodes_[i]], [output_nodes[i]], axes=[axis]), + ] + else: + raise NotImplementedError('indices is supported since ONNX 1.8.0 (opset13), please upgrade ONNX version') + + return nodes diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py index 4ac6dfdff21c..95bb27c7afcc 100644 --- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py +++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset13.py @@ -136,7 +136,6 @@ def get_inputs(node, kwargs): outputs_lookup = kwargs["outputs_lookup"] inputs = node["inputs"] attrs = node.get("attrs", {}) - input_nodes = [] for ip in inputs: input_node_name = outputs_lookup[ip[0]][ip[1]].name @@ -1732,3 +1731,69 @@ def convert_logsoftmax(node, **kwargs): ) return [node] + + +@mx_op.register('_split_v2', OPSET_VERSION) +def convert_contrib_split_v2(node, **kwargs): + """Map MXNet's _split_v2 operator + """ + from onnx.helper import make_node + name, input_nodes, attrs = get_inputs(node, kwargs) + axis = int(attrs.get('axis', 0)) + squeeze_axis = attrs.get('squeeze_axis', 'False') + sections = int(attrs.get('sections', 0)) + indices = convert_string_to_list(attrs.get('indices', '[]')) + if sections <= 0 and len(indices) == 0: + raise NotImplementedError('section or indices must be set') + if sections > 0: + output_nodes = [name+str(i) for i in range(sections)] + if squeeze_axis == 'False': + nodes = [ + make_node('Split', input_nodes, output_nodes, axis=axis), + ] + else: + output_nodes_ = [name+str(i)+'_' for i in range(sections)] + create_tensor([axis], name+'_axis', kwargs['initializer']) + nodes = [ + make_node('Split', input_nodes, output_nodes_, axis=axis), + ] + for i in range(sections): + nodes += [ + make_node("Squeeze", [output_nodes_[i], name+'_axis'], [output_nodes[i]]), + ] + else: + indices.sort() + split = [] + for i in range(1, len(indices)): + if indices[i] >= indices[i-1]: + split.append(indices[i] - indices[i-1]) + + output_nodes = [name+str(i) for i in range(len(split)+1)] + create_tensor([0], name+'_0', kwargs['initializer']) + create_tensor([axis], name+'_axis', kwargs['initializer']) + create_tensor([axis+1], name+'_axis+1', kwargs['initializer']) + create_tensor(split, name+'_split_', kwargs['initializer']) + create_tensor([sum(split)], name+'_sum', kwargs['initializer']) + nodes = [ + make_node('Shape', input_nodes, [name+'_shape']), + make_node('Slice', [name+'_shape', name+'_axis', name+'_axis+1', name+'_0'], [name+'_dim']), + make_node('Sub', [name+'_dim', name+'_sum'], [name+'_sub']), + make_node('Concat', [name+'_split_', name+'_sub'], [name+'_concat'], axis=0), + make_node('Less', [name+'_concat', name+'_0'], [name+'_less']), + make_node('Where', [name+'_less', name+'_0', name+'_concat'], [name+'_split']), + ] + if squeeze_axis == 'False': + nodes += [ + make_node('Split', [input_nodes[0], name+'_split'], output_nodes, axis=axis), + ] + else: + output_nodes_ = [name+str(i)+'_' for i in range(len(split)+1)] + nodes += [ + make_node('Split', [input_nodes[0], name+'_split'], output_nodes_, axis=axis), + ] + for i, output_node in enumerate(output_nodes): + nodes += [ + make_node("Squeeze", [output_nodes_[i], name+'_axis'], [output_node]), + ] + + return nodes diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index a2971b8f139d..cdbebb0d3a93 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -1858,3 +1858,19 @@ def rand_check(out): def rand_check_nd(out): return rand_check(out.asnumpy()) op_export_test('sample_multinomial', M, [x], tmp_path, mx_map=rand_check_nd, onnx_map=rand_check) + + +@pytest.mark.parametrize("dtype", ['float32', 'int32', 'int64']) +@pytest.mark.parametrize('params', [((2, 4, 6), (1, ), 0, True), + ((4, 5, 6), (2, 4), 1, False), + ((4, 5, 6, 7), (0, 2, 4), 2, False), + ((4, 5, 6, 7), 3, -2, False), + ((2, 6, 8), 8, -1, True)]) +def test_onnx_export_split_v2(tmp_path, dtype, params): + from onnx.defs import onnx_opset_version + if onnx_opset_version() < 13 and not isinstance(params[1], int): + # opset12 only supports sections. indices is supported since opset13 + return + M = def_model('split_v2', indices_or_sections=params[1], axis=params[2], squeeze_axis=params[3]) + x = mx.nd.random.uniform(0, 10, params[0]).astype(dtype) + op_export_test('split_v2', M, [x], tmp_path)