From 98cb0b31c7a3a85c2fefe7cad4d737b67f702cd7 Mon Sep 17 00:00:00 2001 From: waytrue17 <52505574+waytrue17@users.noreply.github.com> Date: Tue, 2 Mar 2021 17:23:22 -0800 Subject: [PATCH] [v1.x] ONNX export support for RNN (#19958) * convert RNN * use split * fix sanity * fix param * fix sanity * fix space * add note Co-authored-by: Wei Chu --- .../contrib/onnx/mx2onnx/_op_translations.py | 87 +++++++++++++++++++ .../mxnet/contrib/onnx/mx2onnx/export_onnx.py | 6 +- tests/python-pytest/onnx/test_operators.py | 14 +++ 3 files changed, 106 insertions(+), 1 deletion(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 2521cf5adb37..799957675df7 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -4105,3 +4105,90 @@ def convert_sequence_reverse(node, **kwargs): ] return nodes + + +@mx_op.register("RNN") +def convert_RNN(node, **kwargs): + """Map MXNet's RNN operator attributes to onnx's operators + and return the created node. + """ + from onnx.helper import make_node + from onnx import TensorProto + + name, input_nodes, attrs = get_inputs(node, kwargs) + + mode = str(attrs.get('mode')) + if mode != 'lstm': + raise NotImplementedError('Currently RNN onnx export only supports lstm mode') + + bidirectional = str(attrs.get('bidirectional', 'False')) + if bidirectional != 'False': + raise NotImplementedError('Currently RNN onnx export only supports bidirectional is False') + + num_layers = int(attrs.get('num_layers', '1')) + if num_layers != 1: + raise NotImplementedError('Currently RNN onnx export only supports num_layers equals to 1') + + p = float(attrs.get('p', '0')) + if p != 0: + raise NotImplementedError('Currently RNN onnx export only supports p equals to 0') + + use_sequence_length = str(attrs.get('use_sequence_length', 'False')) + if use_sequence_length != 'False': + raise NotImplementedError('Currently RNN onnx export only supports use_sequence_length equals to False') + + projection_size = str(attrs.get('projection_size', 'None')) + if projection_size != 'None': + raise NotImplementedError('Currently RNN onnx export only supports projection_size equals to None') + + state_outputs = str(attrs.get('state_outputs', 'False')) + if state_outputs != 'True': + raise NotImplementedError('Currently RNN onnx export only supports state_outputs equals to True') + + state_size = int(attrs.get('state_size')) + data = input_nodes[0] + param = input_nodes[1] + initial_h = input_nodes[2] + initial_c = input_nodes[3] + + create_tensor([0], name+'_0', kwargs['initializer']) + create_tensor([1], name+'_1', kwargs['initializer']) + create_tensor([4*state_size], name+'_4*state_size', kwargs['initializer']) + create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer']) + create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer']) + create_tensor([1, 4*state_size, state_size], name+'_R_shape', kwargs['initializer']) + create_tensor([1, 8*state_size], name+'_B_shape', kwargs['initializer']) + + nodes = [ + make_node('Shape', [data], [name+'_data_shape']), + make_node('Split', [name+'_data_shape'], [name+'_seq_length', name+'_batch_size', name+'_input_size']), + # get W + make_node('Mul', [name+'_4*state_size', name+'_input_size'], [name+'_mul0']), + make_node('Slice', [param, name+'_0', name+'_mul0'], [name+'_W_1d']), + make_node('Split', [name+'_W_1d'], [name+'_W0', name+'_W1', name+'_W2', name+'_W3']), + make_node('Concat', [name+'_W0', name+'_W3', name+'_W1', name+'_W2'], [name+'_W_'], axis=0), + make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'], [name+'_W_shape'], axis=0), + make_node('Reshape', [name+'_W_', name+'_W_shape'], [name+'_W']), + # get R + make_node('Add', [name+'_mul0', name+'_4*state_size^2'], [name+'_add0']), + make_node('Slice', [param, name+'_mul0', name+'_add0'], [name+'_R_1d']), + make_node('Split', [name+'_R_1d'], [name+'_R0', name+'_R1', name+'_R2', name+'_R3']), + make_node('Concat', [name+'_R0', name+'_R3', name+'_R1', name+'_R2'], [name+'_R_'], axis=0), + make_node('Reshape', [name+'_R_', name+'_R_shape'], [name+'_R']), + # get B + make_node('Add', [name+'_add0', name+'_8*state_size'], [name+'_add1']), + make_node('Slice', [param, name+'_add0', name+'_add1'], [name+'_B_1d']), + make_node('Split', [name+'_B_1d'], [name+'_B0', name+'_B1', name+'_B2', name+'_B3', + name+'_B4', name+'_B5', name+'_B6', name+'_B7']), + make_node('Concat', [name+'_B0', name+'_B3', name+'_B1', name+'_B2', + name+'_B4', name+'_B7', name+'_B5', name+'_B6'], [name+'_B_'], axis=0), + make_node('Reshape', [name+'_B_', name+'_B_shape'], [name+'_B']), + # get seq_len + make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']), + make_node("Cast", [name+'_seq_len_'], [name+"_seq_len"], to=int(TensorProto.INT32)), + # compute LSTM + make_node('LSTM', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h, initial_c], + [name+'0_', name+'1', name+'2'], hidden_size=state_size), + make_node('Squeeze', [name+'0_'], [name], axes=[1]), + ] + return nodes diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py index af6af8b738a7..898a8df2d5c2 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py +++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py @@ -141,7 +141,11 @@ def get_outputs(sym, params, in_shape, in_label, in_type): out_names = list() for name in sym.list_outputs(): - if name.endswith('_output'): + if name.endswith('_state_output'): # handel special cases for RNN operator + out_names.append(name[:-len('_state_output')]+'1') + elif name.endswith('_statecell_output'): # handel special cases for RNN operator + out_names.append(name[:-len('_statecell_output')]+'2') + elif name.endswith('_output'): out_names.append(name[:-len('_output')]) elif re.search('.*_output[0-9]$', name): out_names.append(name[:-len('_output0')]+name[-1]) diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index f73672b59cc7..8ebfbab2bb47 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -1204,3 +1204,17 @@ def test_onnx_export_sequence_reverse(tmp_path, dtype, params): seq_len = mx.nd.array(params[1]) M1 = def_model('SequenceReverse', use_sequence_length=True) op_export_test('SequenceReverse1', M1, [x, seq_len], tmp_path) + + +# onnx LSTM from opset 11 does not support float64 +@pytest.mark.parametrize('dtype', ['float32']) +@pytest.mark.parametrize('state_size', [128, 256, 512]) +def test_onnx_export_RNN(tmp_path, dtype, state_size): + # the current implementation fails assertion checks for large parm/state_size. + M = def_model('RNN', mode='lstm', state_size=state_size, state_outputs=True, num_layers=1, p=0) + x = mx.nd.random.normal(0, 10, (38, 1, 300), dtype=dtype) + batch_size = np.shape(x)[1] + input_size = np.shape(x)[2] + param = mx.nd.random.normal(0, 1, [4*state_size*input_size + 4*state_size*state_size + 8*state_size], dtype=dtype) + state = mx.nd.random.uniform(-1, 1, [1, batch_size, state_size], dtype=dtype) + cell = mx.nd.random.uniform(-1, 1, [1, batch_size, state_size], dtype=dtype)