This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[v1.x] ONNX add support coverage for Reshape and lstm (#20246)
* lstm and reshape

* fix sanity

* fix sanity

* reduce state_size

Co-authored-by: Wei Chu <[email protected]>
waytrue17 and Wei Chu authored May 6, 2021
1 parent e329e84 commit c3a87e7
Showing 3 changed files with 276 additions and 97 deletions.
182 changes: 138 additions & 44 deletions python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py
@@ -1848,6 +1848,23 @@ def convert_reshape(node, **kwargs):
            make_node('Reshape', [input_nodes[0], name+'_shape_new'], [name], name=name)
        ]

    if targ_shape == [-3, -1] and reverse != 'True':
        special_case = True
        create_tensor([0], name+'_0', kwargs['initializer'])
        create_tensor([1], name+'_1', kwargs['initializer'])
        create_tensor([2], name+'_2', kwargs['initializer'])
        create_tensor([-1], name+'_-1', kwargs['initializer'])
        nodes = [
            make_node('Shape', [input_nodes[0]], [name+'_shape']),
            make_node('Slice', [name+'_shape', name+'_0', name+'_1'], [name+'_1st_dim']),
            make_node('Slice', [name+'_shape', name+'_1', name+'_2'], [name+'_2nd_dim']),
            make_node('Mul', [name+'_1st_dim', name+'_2nd_dim'], [name+'_mul']),
            make_node('Concat', [name+'_mul', name+'_-1'], [name+'_shape_new'], axis=0),
            make_node('Reshape', [input_nodes[0], name+'_shape_new'], [name], name=name),
        ]

    if special_case:
        return nodes
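
For reference, MXNet's Reshape special value -3 merges two consecutive input
dimensions, so the target shape [-3, -1] collapses the first two axes and
flattens the remainder. A minimal NumPy sketch of the shape arithmetic the
node chain above performs (the (2, 3, 4, 5) input shape is illustrative):

import numpy as np

x = np.arange(120).reshape(2, 3, 4, 5)         # hypothetical input

shape = np.array(x.shape)                      # Shape
first_dim = shape[0:1]                         # Slice [0, 1)
second_dim = shape[1:2]                        # Slice [1, 2)
merged = first_dim * second_dim                # Mul
shape_new = np.concatenate([merged, [-1]])     # Concat with the -1 initializer
y = x.reshape(shape_new)                       # Reshape
assert y.shape == (6, 20)                      # (2*3, 4*5)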

@@ -4450,17 +4467,14 @@ def convert_RNN(node, **kwargs):
    from onnx import TensorProto

    name, input_nodes, attrs = get_inputs(node, kwargs)
    mode = str(attrs.get('mode'))

    bidirectional = str(attrs.get('bidirectional', 'False'))
    if bidirectional != 'False' and mode not in ['lstm']:
        raise NotImplementedError('Currently RNN onnx export only supports bidirectional=False for non-lstm modes')

num_layers = int(attrs.get('num_layers', '1'))

p = float(attrs.get('p', '0'))
if p != 0:
raise NotImplementedError('Currently RNN onnx export only supports p equals to 0')

use_sequence_length = str(attrs.get('use_sequence_length', 'False'))
if use_sequence_length != 'False':
raise NotImplementedError('Currently RNN onnx export only supports use_sequence_length equals to False')
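
With this change, bidirectional RNNs are accepted only in 'lstm' mode, while
dropout (p != 0) and use_sequence_length=True still raise NotImplementedError.
A minimal sketch of a fused RNN symbol that takes the newly supported path
(variable names, state_size, and shapes are illustrative):

import mxnet as mx

data = mx.sym.Variable('data')              # (seq_length, batch_size, input_size)
params = mx.sym.Variable('params')          # packed W/R/B parameter vector
state = mx.sym.Variable('state')            # initial hidden state
state_cell = mx.sym.Variable('state_cell')  # initial cell state

# bidirectional single-layer LSTM: exportable after this commit
sym = mx.sym.RNN(data=data, parameters=params, state=state,
                 state_cell=state_cell, mode='lstm', state_size=16,
                 num_layers=1, bidirectional=True)

# the same symbol with mode='gru' would still raise NotImplementedError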
@@ -4481,10 +4495,11 @@ def convert_RNN(node, **kwargs):
    nodes = []
    create_tensor([0], name+'_0', kwargs['initializer'])

    if mode == 'lstm':
        initial_c = input_nodes[3]
        if num_layers == 2:
            if bidirectional != 'False':
                raise NotImplementedError('Currently RNN onnx export only supports bidirectional=False for num_layers=2')
            create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer'])
            create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer'])
            create_tensor([1, 4*state_size, state_size], name+'_WR_shape', kwargs['initializer'])
@@ -4556,45 +4571,124 @@
                make_node('Concat', [name+'_lstm0_c', name+'_lstm1_c'], [name+'2'], axis=0),
            ]
        elif num_layers == 1:
            if bidirectional == 'False':
                create_tensor([1], name+'_1', kwargs['initializer'])
                create_tensor([4*state_size], name+'_4*state_size', kwargs['initializer'])
                create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer'])
                create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer'])
                create_tensor([1, 4*state_size, state_size], name+'_R_shape', kwargs['initializer'])
                create_tensor([1, 8*state_size], name+'_B_shape', kwargs['initializer'])
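                # The fused MXNet parameter vector packs, in order: W
                # (4*state_size*input_size input weights), R (4*state_size*state_size
                # recurrent weights), and B (8*state_size biases); the Slice offsets
                # computed below carve out these three segments.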

                nodes += [
                    make_node('Shape', [data], [name+'_data_shape']),
                    make_node('Split', [name+'_data_shape'],
                              [name+'_seq_length', name+'_batch_size', name+'_input_size']),
                    # get W
                    make_node('Mul', [name+'_4*state_size', name+'_input_size'], [name+'_mul0']),
                    make_node('Slice', [param, name+'_0', name+'_mul0'], [name+'_W_1d']),
                    make_node('Split', [name+'_W_1d'], [name+'_W0', name+'_W1', name+'_W2', name+'_W3']),
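                    # reorder the four gate blocks from MXNet's i, f, g, o gate
                    # layout to the i, o, f, c layout that ONNX LSTM expects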
                    make_node('Concat', [name+'_W0', name+'_W3', name+'_W1', name+'_W2'], [name+'_W_'], axis=0),
                    make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'],
                              [name+'_W_shape'], axis=0),
                    make_node('Reshape', [name+'_W_', name+'_W_shape'], [name+'_W']),
                    # get R
                    make_node('Add', [name+'_mul0', name+'_4*state_size^2'], [name+'_add0']),
                    make_node('Slice', [param, name+'_mul0', name+'_add0'], [name+'_R_1d']),
                    make_node('Split', [name+'_R_1d'], [name+'_R0', name+'_R1', name+'_R2', name+'_R3']),
                    make_node('Concat', [name+'_R0', name+'_R3', name+'_R1', name+'_R2'], [name+'_R_'], axis=0),
                    make_node('Reshape', [name+'_R_', name+'_R_shape'], [name+'_R']),
                    # get B
                    make_node('Add', [name+'_add0', name+'_8*state_size'], [name+'_add1']),
                    make_node('Slice', [param, name+'_add0', name+'_add1'], [name+'_B_1d']),
                    make_node('Split', [name+'_B_1d'], [name+'_B0', name+'_B1', name+'_B2', name+'_B3',
                                                        name+'_B4', name+'_B5', name+'_B6', name+'_B7']),
                    make_node('Concat', [name+'_B0', name+'_B3', name+'_B1', name+'_B2',
                                         name+'_B4', name+'_B7', name+'_B5', name+'_B6'], [name+'_B_'], axis=0),
                    make_node('Reshape', [name+'_B_', name+'_B_shape'], [name+'_B']),
                    # get seq_len
                    make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']),
                    make_node('Cast', [name+'_seq_len_'], [name+'_seq_len'], to=int(TensorProto.INT32)),
                    # compute LSTM
                    make_node('LSTM', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h, initial_c],
                              [name+'0_', name+'1', name+'2'], hidden_size=state_size),
                    make_node('Squeeze', [name+'0_'], [name], axes=[1]),
                ]
            else:
                create_tensor([1], name+'_1', kwargs['initializer'])
                create_tensor([-1], name+'_-1', kwargs['initializer'])
                create_tensor([4*state_size], name+'_4*state_size', kwargs['initializer'])
                create_tensor([8*state_size], name+'_8*state_size', kwargs['initializer'])
                create_tensor([4*state_size*state_size], name+'_4*state_size^2', kwargs['initializer'])
                create_tensor([1, 4*state_size, state_size], name+'_R_shape', kwargs['initializer'])
                create_tensor([1, 8*state_size], name+'_B_shape', kwargs['initializer'])
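                # In the bidirectional case the fused parameter vector packs
                # W_fwd, R_fwd, W_bwd, R_bwd, B_fwd, B_bwd back to back; the
                # running offsets add0..add4 computed below carve out these
                # six segments in order.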

                nodes += [
                    make_node('Shape', [data], [name+'_data_shape']),
                    make_node('Split', [name+'_data_shape'],
                              [name+'_seq_length', name+'_batch_size', name+'_input_size']),
                    # get W_fwd
                    make_node('Mul', [name+'_4*state_size', name+'_input_size'], [name+'_mul0']),
                    make_node('Slice', [param, name+'_0', name+'_mul0'], [name+'_W_1d']),
                    make_node('Split', [name+'_W_1d'], [name+'_W0', name+'_W1', name+'_W2', name+'_W3']),
                    make_node('Concat', [name+'_W0', name+'_W3', name+'_W1', name+'_W2'],
                              [name+'_W_'], axis=0),
                    make_node('Concat', [name+'_1', name+'_4*state_size', name+'_input_size'],
                              [name+'_W_shape'], axis=0),
                    make_node('Reshape', [name+'_W_', name+'_W_shape'], [name+'_W_fwd']),
                    # get R_fwd
                    make_node('Add', [name+'_mul0', name+'_4*state_size^2'], [name+'_add0']),
                    make_node('Slice', [param, name+'_mul0', name+'_add0'], [name+'_R_1d']),
                    make_node('Split', [name+'_R_1d'], [name+'_R0', name+'_R1', name+'_R2', name+'_R3']),
                    make_node('Concat', [name+'_R0', name+'_R3', name+'_R1', name+'_R2'], [name+'_R_'], axis=0),
                    make_node('Reshape', [name+'_R_', name+'_R_shape'], [name+'_R_fwd']),
                    # get W_bwd
                    make_node('Add', [name+'_add0', name+'_mul0'], [name+'_add1']),
                    make_node('Slice', [param, name+'_add0', name+'_add1'], [name+'_W_1d_bwd']),
                    make_node('Split', [name+'_W_1d_bwd'],
                              [name+'_W0_bwd', name+'_W1_bwd', name+'_W2_bwd', name+'_W3_bwd']),
                    make_node('Concat', [name+'_W0_bwd', name+'_W3_bwd', name+'_W1_bwd', name+'_W2_bwd'],
                              [name+'_W_bwd_'], axis=0),
                    make_node('Reshape', [name+'_W_bwd_', name+'_W_shape'], [name+'_W_bwd']),
                    # get R_bwd
                    make_node('Add', [name+'_add1', name+'_4*state_size^2'], [name+'_add2']),
                    make_node('Slice', [param, name+'_add1', name+'_add2'], [name+'_R_1d_bwd']),
                    make_node('Split', [name+'_R_1d_bwd'],
                              [name+'_R0_bwd', name+'_R1_bwd', name+'_R2_bwd', name+'_R3_bwd']),
                    make_node('Concat', [name+'_R0_bwd', name+'_R3_bwd', name+'_R1_bwd', name+'_R2_bwd'],
                              [name+'_R_bwd_'], axis=0),
                    make_node('Reshape', [name+'_R_bwd_', name+'_R_shape'], [name+'_R_bwd']),
                    # get B_fwd
                    make_node('Add', [name+'_add2', name+'_8*state_size'], [name+'_add3']),
                    make_node('Slice', [param, name+'_add2', name+'_add3'], [name+'_B_1d']),
                    make_node('Split', [name+'_B_1d'], [name+'_B0', name+'_B1', name+'_B2', name+'_B3',
                                                        name+'_B4', name+'_B5', name+'_B6', name+'_B7']),
                    make_node('Concat', [name+'_B0', name+'_B3', name+'_B1', name+'_B2',
                                         name+'_B4', name+'_B7', name+'_B5', name+'_B6'], [name+'_B_'], axis=0),
                    make_node('Reshape', [name+'_B_', name+'_B_shape'], [name+'_B_fwd']),
                    # get B_bwd
                    make_node('Add', [name+'_add3', name+'_8*state_size'], [name+'_add4']),
                    make_node('Slice', [param, name+'_add3', name+'_add4'], [name+'_B_1d_bwd']),
                    make_node('Split', [name+'_B_1d_bwd'],
                              [name+'_B0_bwd', name+'_B1_bwd', name+'_B2_bwd', name+'_B3_bwd',
                               name+'_B4_bwd', name+'_B5_bwd', name+'_B6_bwd', name+'_B7_bwd']),
                    make_node('Concat', [name+'_B0_bwd', name+'_B3_bwd', name+'_B1_bwd', name+'_B2_bwd',
                                         name+'_B4_bwd', name+'_B7_bwd', name+'_B5_bwd', name+'_B6_bwd'],
                              [name+'_B_bwd_'], axis=0),
                    make_node('Reshape', [name+'_B_bwd_', name+'_B_shape'], [name+'_B_bwd']),
                    # get seq_len
                    make_node('Tile', [name+'_seq_length', name+'_batch_size'], [name+'_seq_len_']),
                    make_node('Cast', [name+'_seq_len_'], [name+'_seq_len'], to=int(TensorProto.INT32)),
                    # compute LSTM
                    make_node('Concat', [name+'_W_fwd', name+'_W_bwd'], [name+'_W'], axis=0),
                    make_node('Concat', [name+'_R_fwd', name+'_R_bwd'], [name+'_R'], axis=0),
                    make_node('Concat', [name+'_B_fwd', name+'_B_bwd'], [name+'_B'], axis=0),
                    make_node('LSTM', [data, name+'_W', name+'_R', name+'_B', name+'_seq_len', initial_h, initial_c],
                              [name+'0_', name+'1', name+'2'], hidden_size=state_size, direction='bidirectional'),
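                    # ONNX LSTM emits Y with shape (seq_length, num_directions,
                    # batch_size, hidden_size); transpose to (seq, batch, directions,
                    # hidden) and flatten the last two axes to recover MXNet's
                    # (seq_length, batch_size, 2*state_size) output layout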
                    make_node('Transpose', [name+'0_'], [name+'0_t'], perm=[0, 2, 1, 3]),
                    make_node('Concat', [name+'_seq_length', name+'_batch_size', name+'_-1'],
                              [name+'_shape_out'], axis=0),
                    make_node('Reshape', [name+'0_t', name+'_shape_out'], [name]),
                ]
        else:
            raise NotImplementedError('Currently RNN onnx export only supports num_layers equal to 1 or 2')
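
End to end, the new bidirectional path can be exercised through the mx2onnx
exporter. A minimal sketch, assuming MXNet 1.x with this module; the file
name, shapes, and the empty params dict are illustrative:

import mxnet as mx
import numpy as np

seq_len, batch, input_size, state_size = 10, 4, 8, 16

data = mx.sym.Variable('data')
params = mx.sym.Variable('params')
state = mx.sym.Variable('state')
state_cell = mx.sym.Variable('state_cell')
sym = mx.sym.RNN(data=data, parameters=params, state=state,
                 state_cell=state_cell, mode='lstm', state_size=state_size,
                 num_layers=1, bidirectional=True)

# packed parameter count for one bidirectional LSTM layer:
# 2 * (4*state_size*input_size + 4*state_size*state_size + 8*state_size)
num_params = 2 * (4*state_size*input_size
                  + 4*state_size*state_size + 8*state_size)

in_shapes = [(seq_len, batch, input_size), (num_params,),
             (2, batch, state_size), (2, batch, state_size)]
in_types = [np.float32] * 4
mx.onnx.export_model(sym, {}, in_shapes, in_types, 'bilstm.onnx')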

