Fix MVN in MO #4311

Merged: 7 commits, Feb 20, 2021
Changes from 6 commits
model-optimizer/automation/package_BOM.txt: 2 additions & 1 deletion
@@ -96,7 +96,7 @@ extensions/front/caffe/input_ext.py
extensions/front/caffe/interp_ext.py
extensions/front/caffe/lrn_ext.py
extensions/front/caffe/mvn_ext.py
extensions/front/caffe/MVNNormalizer.py
extensions/front/caffe/MVNCaffeToMVN.py
extensions/front/caffe/normalize_ext.py
extensions/front/caffe/permute_ext.py
extensions/front/caffe/pooling_ext.py
@@ -294,6 +294,7 @@ extensions/front/onnx/mask_rcnn.json
extensions/front/onnx/mask_rcnn_conversion.py
extensions/front/onnx/matmul_ext.py
extensions/front/onnx/mean_variance_normalization_ext.py
extensions/front/onnx/MvnOnnxToMvn.py
extensions/front/onnx/non_max_suppression_ext.py
extensions/front/onnx/non_zero_ext.py
extensions/front/onnx/normalize_ext.py
@@ -18,7 +18,7 @@

import numpy as np

from extensions.front.caffe.MVNNormalizer import MVNCaffeToMVN
from extensions.front.caffe.MVNCaffeToMVN import MVNCaffeToMVN
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph, regular_op_with_empty_data, result, const, connect_front

@@ -37,7 +37,7 @@
}


class MVNNormalizerTest(unittest.TestCase):
class MVNCaffeToMVNTest(unittest.TestCase):
def test_mvn_normalizer(self):
graph = build_graph(nodes, [('input', 'mvn_caffe'),
('mvn_caffe', 'output')],
model-optimizer/extensions/front/caffe/mvn_ext.py: 7 additions & 2 deletions
@@ -13,7 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
from extensions.ops.mvn import MVNCaffe
from extensions.ops.mvn import AttributedMVN
from mo.front.caffe.collect_attributes import collect_attributes
from mo.front.extractor import FrontExtractorOp

@@ -29,6 +29,11 @@ def extract(cls, node):

attrs = collect_attributes(param)

if 'normalize_variance' not in attrs:
attrs['normalize_variance'] = 1
if 'across_channels' not in attrs:
attrs['across_channels'] = 0

# update the attributes of the node
MVNCaffe.update_node_stat(node, attrs)
AttributedMVN.update_node_stat(node, attrs)
return cls.enabled
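
For readers unfamiliar with the Caffe layer, here is a plain-numpy sketch (not code from this PR) of what the two attributes the extractor now defaults mean for an NCHW input; treating eps as added outside the square root is an assumption made for illustration, mirroring the 'outside_sqrt' mode used elsewhere in this PR.

import numpy as np

def caffe_mvn_reference(x, normalize_variance=1, across_channels=0, eps=1e-9):
    # across_channels=0: statistics per channel (reduce over H, W only);
    # across_channels=1: statistics over C, H and W together
    axes = (1, 2, 3) if across_channels else (2, 3)
    out = x - x.mean(axis=axes, keepdims=True)
    if normalize_variance:
        # eps added to the standard deviation, i.e. outside the square root (assumption)
        out = out / (np.sqrt((out ** 2).mean(axis=axes, keepdims=True)) + eps)
    return out

x = np.random.rand(1, 3, 5, 2).astype(np.float32)
y = caffe_mvn_reference(x)  # with the new defaults: per-channel zero mean, unit variance
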
model-optimizer/extensions/front/onnx/MvnOnnxToMvn.py: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
"""
Copyright (C) 2017-2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from extensions.ops.mvn import MVN
from mo.front.common.replacement import FrontReplacementPattern
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph, rename_nodes


class MvnOnnxToMvn(FrontReplacementPattern):
"""
Replace the internal MVNOnnx operation (produced by the ONNX MVN extractor) with MVN
"""
enabled = True

def find_and_replace_pattern(self, graph: Graph):
for node in graph.get_op_nodes(op='MVNOnnx'):
node_name = node.soft_get('name', node.id)

new_mvn = create_op_with_const_inputs(graph, MVN, {1: node.axes},
{'eps': node.eps,
'eps_mode': node.eps_mode,
'normalize_variance': node.normalize_variance})
node.in_port(0).get_connection().set_destination(new_mvn.in_port(0))
node.out_port(0).get_connection().set_source(new_mvn.out_port(0))
rename_nodes([(node, node_name + '/to_be_removed'), (new_mvn, node_name)])

graph.remove_node(node.id)
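
As a quick illustration (a schematic, not code from this PR), the replacement moves the axes from a node attribute to a constant second input; the unit test in the next file builds exactly this pair of graphs.

# Schematic only; node names are illustrative.
#
#   before:  input -> MVNOnnx {axes=[2, 3], eps, eps_mode, normalize_variance} -> result
#
#   after:   input --------------> MVN {eps, eps_mode, normalize_variance} -> result
#            Const([2, 3]) ------>     (axes arrive on input port 1)
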
model-optimizer/extensions/front/onnx/MvnOnnxToMvn_test.py: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
"""
Copyright (C) 2017-2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import unittest

import numpy as np

from extensions.front.onnx.MvnOnnxToMvn import MvnOnnxToMvn
from mo.front.common.partial_infer.utils import int64_array
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph, regular_op_with_empty_data, result, const, connect_front

nodes = {
**regular_op_with_empty_data('input', {'type': 'Parameter'}),
**regular_op_with_empty_data('mvn_onnx', {'op': 'MVNOnnx',
'axes': int64_array([2, 3]),
'eps': 1e-9,
'eps_mode': 'outside_sqrt',
'normalize_variance': 1}),
**result(),

# nodes after replacement
**const('axes', int64_array([2, 3])),
**regular_op_with_empty_data('mvn', {'op': 'MVN', 'type': None}),
}


class MvnOnnxToMvnTest(unittest.TestCase):
def test_mvn_normalize(self):
graph = build_graph(nodes, [('input', 'mvn_onnx'),
('mvn_onnx', 'output')],
nodes_with_edges_only=True)
graph.stage = 'front'

MvnOnnxToMvn().find_and_replace_pattern(graph)

graph_ref = build_graph(nodes, [('input', 'mvn'),
*connect_front('axes', '1:mvn'),
('mvn', 'output')],
nodes_with_edges_only=True)

(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
self.assertTrue(flag, resp)
model-optimizer/extensions/front/onnx/mean_variance_normalization_ext.py
@@ -16,10 +16,10 @@

import numpy as np

from extensions.ops.mvn import MVN
from extensions.ops.mvn import MVNOnnx
from mo.front.common.partial_infer.utils import int64_array
from mo.front.extractor import FrontExtractorOp
from mo.front.onnx.extractors.utils import onnx_attr
from mo.ops.const import Const


class MeanVarianceNormalizationExtractor(FrontExtractorOp):
@@ -28,20 +28,16 @@ class MeanVarianceNormalizationExtractor(FrontExtractorOp):

@classmethod
def extract(cls, node):
name = node.soft_get('name', node.id)
axes = onnx_attr(node, 'axes', 'ints',
default=np.array([0, 2, 3], dtype=np.int64),
default=int64_array([0, 2, 3]),
dst_type=lambda x: np.array(x, dtype=np.int64))

axes = Const(node.graph, {'value': axes, 'name': name + '/Axes'}).create_node()
node.add_input_port(1, skip_if_exist=True)
node.in_port(1).connect(axes.out_port(0))

attrs = {
'eps': 1e-9,
'normalize_variance': 1,
'eps_mode': 'outside_sqrt'
'axes': axes,
'eps_mode': 'outside_sqrt',
}

MVN.update_node_stat(node, attrs)
MVNOnnx.update_node_stat(node, attrs)
return cls.enabled
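
For context, a plain-numpy sketch (not code from this PR) of the semantics behind the attributes the extractor now records on the internal MVNOnnx node; the exact formula is an assumption for illustration, and the real computation happens in the MVN operation after MvnOnnxToMvn rewires the graph.

import numpy as np

def mvn_reference(x, axes=(0, 2, 3), eps=1e-9, eps_mode='outside_sqrt'):
    axes = tuple(axes)
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    if eps_mode == 'outside_sqrt':
        return (x - mean) / (np.sqrt(var) + eps)   # eps added to the standard deviation
    return (x - mean) / np.sqrt(var + eps)         # 'inside_sqrt': eps added to the variance

x = np.random.rand(2, 3, 4, 4).astype(np.float32)
y = mvn_reference(x)  # axes default mirrors the extractor's [0, 2, 3]
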
model-optimizer/extensions/middle/GroupNorm.py: 2 additions & 2 deletions
@@ -123,12 +123,12 @@ def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
mvn_node.in_port(0).connect(reshape_for_mvn_node.out_port(0))

# MVN axes
_, rank = get_shape_and_rank_nodes_by_port(mvn_node.in_port(0), return_as_a_scalar=True)
_, rank = get_shape_and_rank_nodes_by_port(mvn_node.in_port(0).get_connection().get_source(),
return_as_a_scalar=True)
rng = create_op_with_const_inputs(graph, Range, {0: int64_array(1), 2: int64_array(1)},
{'name': group_norm_node.name + '/Range', 'output_type': np.int64})
mvn_node.in_port(1).connect(rng.out_port(0))
rng.in_port(1).connect(rank.out_port(0))
mvn_node.in_port(0).get_connection().add_destination(rank.in_port(0))

# reshape to the initial shape before multiplying with gamma and adding beta
reshape_to_initial_shape_node = Reshape(graph, {}).create_node()
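
For context, a plain-numpy sketch (not code from this PR) of the decomposition GroupNormToMVN emits and that the corrected Rank computation feeds: reshape [N, C, H, W] to [N*G, C//G, H, W], apply MVN over axes [1, rank) of the reshaped tensor, reshape back, then scale by gamma and add beta; the eps placement is an assumption for illustration.

import numpy as np

def group_norm_reference(x, gamma, beta, num_groups, eps=1e-9):
    n, c, h, w = x.shape
    g = x.reshape(n * num_groups, c // num_groups, h, w)
    axes = tuple(range(1, g.ndim))        # Range(1, rank) over the reshaped MVN input
    mean = g.mean(axis=axes, keepdims=True)
    std = np.sqrt(g.var(axis=axes, keepdims=True))
    g = (g - mean) / (std + eps)
    return g.reshape(n, c, h, w) * gamma + beta

x = np.random.rand(1, 3, 5, 2).astype(np.float32)
y = group_norm_reference(x, gamma=0.5, beta=0.5, num_groups=3)  # shapes match the test below
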
model-optimizer/extensions/middle/GroupNorm_test.py: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
"""
Copyright (C) 2018-2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import unittest

from extensions.middle.GroupNorm import GroupNormToMVN
from mo.front.common.partial_infer.utils import float_array, int64_array
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph, result, build_graph_with_edge_attrs, connect, \
regular_op_with_shaped_data, valued_const_with_data, connect_data

shape = int64_array([1, 3, 5, 2])
nodes = {**regular_op_with_shaped_data('input', shape, {'type': 'Parameter', 'op': 'Parameter'}),
**valued_const_with_data('gamma', float_array([0.5])),
**valued_const_with_data('beta', float_array([0.5])),
**regular_op_with_shaped_data('group_norm', shape,
{'op': 'GroupNorm', 'name': 'group_norm', 'num_groups': 3, 'eps': 1e-9}),
**result('result')
}

edges = [*connect('input:0', '0:group_norm'),
*connect('gamma', '1:group_norm'),
*connect('beta', '2:group_norm'),
*connect('group_norm:0', 'result'),
]

ref_nodes = {**regular_op_with_shaped_data('input', shape, {'type': 'Parameter', 'op': 'Parameter'}),
**regular_op_with_shaped_data('shape1', int64_array([4]), {'op': 'ShapeOf'}),
**regular_op_with_shaped_data('shape2', int64_array([4]), {'op': 'ShapeOf'}),
**regular_op_with_shaped_data('shape3', int64_array([1]), {'op': 'ShapeOf'}),
**regular_op_with_shaped_data('hcast1', int64_array([4]), {'op': 'Cast'}),
**regular_op_with_shaped_data('cast2', int64_array([2]), {'op': 'Cast'}),
**regular_op_with_shaped_data('cast3', int64_array([4]), {'op': 'Cast'}),
**regular_op_with_shaped_data('gather1', int64_array([2]), {'op': 'Gather'}),
**regular_op_with_shaped_data('gather2', int64_array([1]), {'op': 'Gather'}),
**regular_op_with_shaped_data('gather3', int64_array([1]), {'op': 'Gather'}),
**regular_op_with_shaped_data('mul1', int64_array([1]), {'op': 'Mul'}),
**regular_op_with_shaped_data('mul2', int64_array([1]), {'op': 'Mul'}),
**regular_op_with_shaped_data('mul3', shape, {'op': 'Mul'}),
**regular_op_with_shaped_data('concat', int64_array([4]), {'op': 'Concat'}),
**regular_op_with_shaped_data('reshape1', int64_array([3, 1, 5, 2]), {'op': 'Reshape'}),
**regular_op_with_shaped_data('reshape2', shape, {'op': 'Reshape'}),
**regular_op_with_shaped_data('squeeze', int64_array([]), {'op': 'Squeeze'}),
**regular_op_with_shaped_data('range', int64_array([3]), {'op': 'Range'}),
**regular_op_with_shaped_data('mvn', int64_array([3, 1, 5, 2]), {'op': 'MVN'}),
**regular_op_with_shaped_data('add', shape, {'op': 'Add'}),
**valued_const_with_data('shape/axis1', int64_array(0)),
**valued_const_with_data('shape/ind1', int64_array([2, 3])),
**valued_const_with_data('shape/axis2', int64_array(0)),
**valued_const_with_data('shape/ind2', int64_array([0])),
**valued_const_with_data('shape/axis3', int64_array(0)),
**valued_const_with_data('shape/ind3', int64_array([1])),
**valued_const_with_data('gn/rec', float_array([1./3])),
**valued_const_with_data('group', int64_array([3])),
**valued_const_with_data('squeeze/axis', int64_array([0])),
**valued_const_with_data('range/start', int64_array(1)),
**valued_const_with_data('range/step', int64_array(1)),
**valued_const_with_data('gamma', float_array([[[[0.5]]]])),
**valued_const_with_data('beta', float_array([[[[0.5]]]])),
**result('result')
}
ref_edges = [*connect('input', '0:reshape1'),
*connect('input', 'shape1', skip_data=True),
*connect('shape1:0', '0:gather1'),
*connect('shape1:0', 'hcast1', skip_data=True),
*connect('shape/ind1', '1:gather1'),
*connect('shape/axis1', '2:gather1'),
*connect('gather1', 'cast2'),
*connect('hcast1', '0:gather3'),
*connect('hcast1', '0:gather2', skip_data=True),
*connect('shape/ind2', '1:gather2'),
*connect('shape/axis2', '2:gather2'),
*connect('gather2', '0:mul2'),
*connect('group', '1:mul2'),
*connect('shape/ind3', '1:gather3'),
*connect('shape/axis3', '2:gather3'),
*connect('gather3', '0:mul1'),
*connect('gn/rec', '1:mul1'),
*connect('mul2', '0:concat'),
*connect('mul1', '1:concat'),
*connect('cast2', '2:concat'),
*connect('concat', 'cast3'),
*connect('cast3', '1:reshape1'),
*connect('reshape1', 'shape2'),
*connect('shape2', 'shape3'),
*connect('shape3', '0:squeeze'),
*connect('squeeze/axis', '1:squeeze'),
*connect('range/start', '0:range'),
*connect('squeeze', '1:range'),
*connect('range/step', '2:range'),
*connect('reshape1', '0:mvn', skip_data=True),
*connect('range', '1:mvn'),
*connect('mvn', '0:reshape2'),
*connect('shape1:0', '1:reshape2', skip_data=True),
*connect('reshape2', '0:mul3'),
*connect('gamma', '1:mul3'),
*connect('mul3', '0:add'),
*connect('beta', '1:add'),
*connect('add', 'result')
]


class GroupNormToMVNTest(unittest.TestCase):
def test_group_norm_1(self):
graph = build_graph(nodes, edges)

graph_ref = build_graph(ref_nodes, ref_edges)

graph.graph['layout'] = 'NCHW'

GroupNormToMVN().find_and_replace_pattern(graph)
graph.clean_up()

(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)
model-optimizer/extensions/ops/mvn.py: 22 additions & 2 deletions
@@ -70,6 +70,26 @@ def infer(node: None):
copy_shape_infer(node)


class MVNOnnx(Op):
op = 'MVNOnnx'
enabled = False

def __init__(self, graph: Graph, attrs: dict):
super().__init__(graph, {
'kind': 'op',
'type': None,
'op': self.op,
'version': None,
'eps': None,
'eps_mode': None,
'normalize_variance': None,
'axes': None,
'in_ports_count': 1,
'out_ports_count': 1,
'infer': None
}, attrs)


class MVNCaffe(Op):
op = 'MVNCaffe'
enabled = False
@@ -81,8 +101,8 @@ def __init__(self, graph: Graph, attrs: dict):
'op': self.op,
'version': None,
'eps': 1e-9,
'normalize_variance': 1,
'across_channels': 0,
'normalize_variance': None,
'across_channels': None,
'in_ports_count': 1,
'out_ports_count': 1,
'infer': None
model-optimizer/mo/utils/unittest/graph.py: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -359,7 +359,7 @@ def connect(first_tensor_name, second_tensor_name, skip_data=False, front_phase=
second_op_name, in_port = get_name_and_port(second_tensor_name)

if skip_data:
return [(first_op_name + '_d', second_op_name, {'in': in_port})]
return [(first_op_name + '_d', second_op_name, {'out': out_port, 'in': in_port})]
if front_phase:
return [(first_op_name, second_op_name, {'out': out_port, 'in': in_port})]
return [