Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add Flatten before Gemm
Browse files Browse the repository at this point in the history
  • Loading branch information
vandanavk committed Nov 26, 2018
1 parent e9c8db7 commit 8cd959d
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,17 @@ def convert_fully_connected(node, **kwargs):

fcnode = []

op_name = "flatten_" + str(kwargs["idx"])
flatten_node = onnx.helper.make_node(
'Flatten',
inputs=[input_nodes[0]],
outputs=[op_name],
name=op_name
)

input_nodes[0] = op_name
fcnode.append(flatten_node)

if no_bias:
data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype('int64')]
bias_name = "bias" + str(kwargs["idx"])
Expand Down

0 comments on commit 8cd959d

Please sign in to comment.