[Numpy] Bugfix of slice operator export (MXNet to ONNX) #17827
Conversation
Hi, there is actually an additional problem:

```python
converted = MXNetGraph.convert_layer(
    node,
    is_input=False,
    mx_graph=mx_graph,
    weights=weights,
    in_shape=in_shape,
    in_type=in_type,
    proc_nodes=all_processed_nodes,
    initializer=initializer,
    index_lookup=index_lookup,
    idx=idx
)
```

I suggest inferring the current shape before each layer and refactoring …
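To illustrate the idea, here is a minimal, self-contained sketch (the toy network, the single `data` input, and the `shape_lookup` name are my own assumptions, not code from the exporter): infer the shapes of every internal output once, then look up each node's actual input shape during conversion instead of reusing the network input shape for every layer.

```python
import mxnet as mx

# Toy symbol whose shape changes between layers (illustration only).
data = mx.sym.Variable('data')
net = mx.sym.FullyConnected(data, num_hidden=100, name='fc1')
net = mx.sym.split(net, axis=1, num_outputs=2, name='split1')
net = mx.sym.concat(*[net[i] for i in range(2)], name='concat1')

# Infer the shapes of all internal outputs in one pass.
internals = net.get_internals()
_, out_shapes, _ = internals.infer_shape(data=(1, 16))
shape_lookup = dict(zip(internals.list_outputs(), out_shapes))

# e.g. shape_lookup['fc1_output'] == (1, 100); such a per-node shape could
# then be passed as in_shape when converting the layer that consumes it.
```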
@QueensGambit I am personally interested in fixing this issue. I've also developed a fix for the …

P.S. You might also want to rebase on top of current master. I think there were some CI/CD changes recently, so maybe that will improve the problems with CI validation.
Hello @ruro, great to see that you have been working on the shape inference issue. The network graph can be visualized like this:

```python
display(mx.viz.plot_network(
    symbol,
    shape={'data': input_shape},
    node_attrs={"shape": "oval", "fixedsize": "false"}
))
```

The corresponding code for inferring the shapes looks like this:

```python
internals = symbol.get_internals()
input_name = "data"
_, out_shapes, _ = internals.infer_shape(**{input_name: input_shape})
```

This way you would only require …
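As a small follow-up sketch (assuming the same `symbol` and `input_shape` as above): the entries of `out_shapes` are ordered like `internals.list_outputs()`, so the two can be zipped into a per-layer shape table.

```python
# Pair each internal output name with its inferred shape.
for name, shape in zip(internals.list_outputs(), out_shapes):
    print(name, shape)
```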
@QueensGambit I am not quite sure what you mean by … The approach I use in my PR extracts the shapes in exactly the way you describe:

```python
graph_shapes = MXNetGraph.get_outputs(sym.get_internals(), params, in_shape, output_label)
```

by calling `get_outputs` to effectively do exactly

```python
_, out_shapes, _ = internals.infer_shape(**inputs)
```

P.S. You are free to change the code in my PR for …
@ruro See QueensGambit#1 for a reply on your PR.
@QueensGambit I am interested in getting this fix accepted. Could you solve the merge conflict and rebase on top of master? The conflict is in the unittest, but it's pretty easy to fix. Just add a …

A patch like this should work:

```diff
diff --git a/tests/python/unittest/onnx/mxnet_export_test.py b/tests/python/unittest/onnx/mxnet_export_test.py
index 40d7d4e3e..3b039b426 100644
--- a/tests/python/unittest/onnx/mxnet_export_test.py
+++ b/tests/python/unittest/onnx/mxnet_export_test.py
@@ -28,6 +28,7 @@ from common import setup_module, teardown_module, with_seed
 from mxnet import nd, sym
 from mxnet.test_utils import set_default_context
 from mxnet.gluon import nn
+from mxnet.gluon import HybridBlock
 from mxnet.contrib import onnx as onnx_mxnet
 import mxnet as mx
@@ -80,6 +81,16 @@ def _check_onnx_export(net, group_outputs=False, shape_type=tuple, extra_params=
     mx.test_utils.assert_almost_equal(out, imp_out, atol=1e-5, rtol=1e-5)
 
 
+class SplitConcatBlock(HybridBlock):
+    """Block which creates two splits and later concatenates them"""
+    def __init__(self, name):
+        super(SplitConcatBlock, self).__init__(name)
+
+    def hybrid_forward(self, F, x):
+        splits = F.split(x, axis=1, num_outputs=2)
+        return F.concat(*splits)
+
+
 class TestExport(unittest.TestCase):
     """ Tests ONNX export.
     """
@@ -126,3 +137,10 @@ class TestExport(unittest.TestCase):
         net.add(nn.Dense(100, activation='relu'), nn.Dense(10))
         _check_onnx_export(net, extra_params={'extra_param': nd.array([1, 2])})
 
+    @with_seed()
+    def test_onnx_export_slice(self):
+        net = nn.HybridSequential(prefix='slice_net')
+        with net.name_scope():
+            net.add(nn.Dense(100, activation='relu'), SplitConcatBlock("splitConcat"), nn.Dense(10))
+        _check_onnx_export(net)
+
```
@ruro I added a verbose parameter to `get_outputs()`:

```python
graph_shapes = MXNetGraph.get_outputs(sym.get_internals(), params, in_shape, output_label, verbose=False)
```

I also added another unit test for using the slice operation with changing shapes.
Yeah, don't worry. I am guessing you merged instead of rebasing. What I meant is to do …
force-pushed from 09eee04 to 79ba0e5
@ruro OK, thank you. I was able to fix this mess by going back and using a force push.
Unfortunately, it seems that this doesn't clear the codeowners, so the PR still requires the review of 10 different maintainers. You'll probably have to create a new PR after all. :(

The last "added ONNX unit test export" commit doesn't seem correct, since it updates the submodules. I think you can just drop this commit, since there are no actual changes left in it apart from the accidental submodule update.
force-pushed from 79ba0e5 to 0c6411f
Description
This PR fixes the slice operator export from MXNet into ONNX.
@MoritzMaxeiner already reported this problem in Incorrect ONNX export of SliceChannel #13061.
The corresponding pull request, ONNX export: Support equal length splits #14121, however, only partially solved the problem.
The output of the slice operator was corrected, but the operator is still not handled properly in the `get_inputs()` function. The following unit test demonstrates this:
Before this PR, converting such a model from MXNet to ONNX failed with an error.
Now, this is resolved by changing the shape inference in `get_inputs()` to run over all internal nodes of the symbol, as shown below.
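This is the updated call as it is quoted in the conversation above; the surrounding exporter code is not reproduced here.

```python
# Infer shapes over all internal nodes so that each layer is exported with the
# shape of its own input rather than the shape of the network input.
graph_shapes = MXNetGraph.get_outputs(sym.get_internals(), params, in_shape,
                                      output_label, verbose=False)
```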
Ping: @Roshrini @vandanavk @ChaiBapchya
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
- Added the unit test `test_onnx_export_slice()`
- Fixed the shape handling in `get_inputs()`
Comments