-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ MO Interpolate ] Fixing broken model reshape-ability (#619)
- Loading branch information
Evgenya Stepyreva
authored
May 29, 2020
1 parent
5cc8114
commit e290b14
Showing
3 changed files
with
252 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
""" | ||
Copyright (C) 2018-2020 Intel Corporation | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
import numpy as np | ||
|
||
from extensions.ops.elementwise import Mul | ||
from extensions.ops.gather import Gather | ||
from mo.back.replacement import BackReplacementPattern | ||
from mo.front.caffe.extractors.utils import get_canonical_axis_index | ||
from mo.front.common.partial_infer.utils import int64_array | ||
from mo.front.tf.graph_utils import create_op_with_const_inputs, create_op_node_with_second_input | ||
from mo.graph.graph import Graph | ||
from mo.ops.shape import Shape | ||
|
||
|
||
class InterpolateConcat(BackReplacementPattern): | ||
""" | ||
Replaces hard-coded 1-port input of Interpolate with reshape-able sub-graph using the following Concat inputs | ||
BEFORE: | ||
input Const | ||
shape=[1, 3, 30, 40] value=[60, 160] | ||
\ / | ||
Interpolate(axes=(2, 3)) input_1 | ||
shape=[1, 3, 60, 160] shape=[1, 4, 60, 160] | ||
\ / | ||
Concat(axis=1) | ||
shape=[1, 7, 60, 160] | ||
AFTER: | ||
input | ||
shape=[1, 3, 30, 40] input_1 | ||
| shape=[1, 4, 60, 160] | ||
| / | | ||
| ShapeOf | | ||
| | | | ||
| Gather | | ||
| indices=(2, 3); axis=0 | | ||
\ | | | ||
Interpolate(axes=(2, 3)) | | ||
shape=[1, 3, 60, 160] | | ||
\ / | ||
Concat(axis=1) | ||
shape=[1, 7, 60, 160] | ||
""" | ||
enabled = True | ||
graph_condition = [lambda graph: graph.graph['cmd_params'].keep_shape_ops] | ||
force_shape_inference = True | ||
id = 'reshape_interpolate_through_concat' | ||
|
||
@staticmethod | ||
def make_interpolate_reshapeable(interpolate, concat): | ||
assert interpolate.soft_get('type') == 'Interpolate' | ||
assert concat.soft_get('type') == 'Concat' | ||
|
||
output_shape = interpolate.out_port(0).data.get_shape() | ||
|
||
interp_axes = [get_canonical_axis_index(output_shape, axis) for axis in interpolate.axes] | ||
concat_axis = get_canonical_axis_index(output_shape, concat.axis) | ||
if concat_axis in interp_axes: | ||
return | ||
|
||
concat_srcs = [port.get_source() for port in concat.in_ports().values()] | ||
non_interp_concat_srcs = [src for src in concat_srcs if src.node.soft_get('type') != 'Interpolate'] | ||
if len(non_interp_concat_srcs) == 0: | ||
return | ||
|
||
graph = interpolate.graph | ||
src = non_interp_concat_srcs[0] | ||
|
||
shape = Shape(graph, {'name': src.node.soft_get('name', src.node.id) + '/Shape'}).create_node() | ||
shape.in_port(0).connect(src) | ||
gather = create_op_with_const_inputs(graph, Gather, {1: np.array(interpolate.axes, dtype=np.int32), 2: int64_array(0)}, | ||
{'name': shape.name + '/Gathered'}, shape) | ||
interpolate.in_port(1).get_connection().set_source(gather.out_port(0)) | ||
|
||
def find_and_replace_pattern(self, graph: Graph): | ||
for interpolate in graph.get_op_nodes(type='Interpolate'): | ||
if interpolate.in_port(1).get_source().node.soft_get('type') != 'Const': | ||
continue | ||
dsts = interpolate.out_port(0).get_destinations() | ||
if len(dsts) == 1 and dsts[0].node.soft_get('type') == 'Concat': | ||
self.make_interpolate_reshapeable(interpolate, dsts[0].node) | ||
|
||
|
||
class InterpolateReshapeWA(BackReplacementPattern): | ||
""" | ||
Replaces hard-coded 1-port input of Interpolate with reshape-able sub-graph. | ||
WARNING: Could cause troubles if model has hard-coded Interpolate intentionally -- rare situation | ||
BEFORE: | ||
input Const | ||
shape=[1, 3, 30, 40] value=[60, 160] | ||
\ / | ||
Interpolate(axes=(2, 3)) | ||
shape=[1, 3, 60, 160] | ||
AFTER: | ||
input | ||
shape=[1, 3, 30, 40] | ||
| \ | ||
| ShapeOf | ||
| | | ||
| Gather Const | ||
| indices=(2, 3); axis=0 value=[2, 4] | ||
| \ / | ||
| Multiply | ||
| / | ||
Interpolate(axes=(2, 3)) | ||
shape=[1, 3, 60, 160] | ||
""" | ||
enabled = False | ||
graph_condition = [lambda graph: graph.graph['cmd_params'].keep_shape_ops] | ||
force_shape_inference = True | ||
id = 'reshape_interpolate_wa' | ||
|
||
def run_after(self): | ||
return [InterpolateConcat] | ||
|
||
@staticmethod | ||
def make_interpolate_reshapeable(interpolate): | ||
assert interpolate.soft_get('type') == 'Interpolate' | ||
axes = interpolate.axes | ||
input_shape = interpolate.in_port(0).data.get_shape() | ||
output_shape = interpolate.out_port(0).data.get_shape() | ||
if not np.all(np.remainder(output_shape, input_shape) == 0) and \ | ||
not np.all(np.remainder(input_shape, output_shape) == 0): | ||
return | ||
graph = interpolate.graph | ||
name = interpolate.soft_get('name', interpolate.id) | ||
shape = Shape(graph, {'name': name + '/ShapeOf'}).create_node() | ||
shape.in_port(0).connect(interpolate.in_port(0).get_source()) | ||
gather = create_op_with_const_inputs(graph, Gather, {1: np.array(axes, dtype=np.int32), 2: int64_array(0)}, | ||
{'name': shape.name + '/Gathered'}, shape) | ||
multipliers = output_shape[axes] / input_shape[axes] | ||
mul = create_op_node_with_second_input(graph, Mul, multipliers, {'name': gather.name + '/Multiplied'}, gather) | ||
interpolate.in_port(1).get_connection().set_source(mul.out_port(0)) | ||
|
||
def find_and_replace_pattern(self, graph: Graph): | ||
for interpolate in graph.get_op_nodes(type='Interpolate'): | ||
if interpolate.in_port(1).get_source().node.soft_get('type') == 'Const': | ||
self.make_interpolate_reshapeable(interpolate) |
97 changes: 97 additions & 0 deletions
97
model-optimizer/extensions/back/InterpolateReshape_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
""" | ||
Copyright (C) 2018-2020 Intel Corporation | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
import unittest | ||
from argparse import Namespace | ||
|
||
import numpy as np | ||
|
||
from extensions.back.InterpolateReshape import InterpolateReshapeWA, InterpolateConcat | ||
from mo.utils.ir_engine.compare_graphs import compare_graphs | ||
from mo.utils.unittest.graph import build_graph, result, regular_op_with_shaped_data, valued_const_with_data, connect, \ | ||
connect_data | ||
|
||
nodes = { | ||
**regular_op_with_shaped_data('placeholder', [1, 3, 30, 40], {'type': 'Parameter'}), | ||
**valued_const_with_data('out_shape', np.array([60, 160])), | ||
|
||
**regular_op_with_shaped_data('interpolate', [1, 3, 60, 160], {'type': 'Interpolate', 'axes': [2, 3]}), | ||
|
||
**regular_op_with_shaped_data('shape', [4], {'type': 'ShapeOf'}), | ||
**valued_const_with_data('indices', np.array([2, 3])), | ||
**valued_const_with_data('axis', np.array(0)), | ||
**regular_op_with_shaped_data('gather', [2], {'type': 'Gather'}), | ||
|
||
**valued_const_with_data('multiplier', np.array([2, 4])), | ||
**regular_op_with_shaped_data('mul', [2], {'type': 'Multiply'}), | ||
|
||
**regular_op_with_shaped_data('placeholder_1', [1, 3, 60, 160], {'type': 'Parameter'}), | ||
**regular_op_with_shaped_data('concat', [1, 7, 60, 160], {'type': 'Concat', 'axis': 1}), | ||
|
||
**result(), | ||
} | ||
|
||
|
||
class TestInterpolateReshapeWA(unittest.TestCase): | ||
def test_interpolate_reshape_graph_comparison(self): | ||
graph = build_graph(nodes, [ | ||
*connect('placeholder', '0:interpolate'), | ||
*connect('out_shape', '1:interpolate'), | ||
*connect('interpolate', 'output'), | ||
], nodes_with_edges_only=True) | ||
InterpolateReshapeWA().find_and_replace_pattern(graph) | ||
graph.graph['cmd_params'] = Namespace(keep_shape_ops=True) | ||
graph.clean_up() | ||
graph_ref = build_graph(nodes, [ | ||
*connect('placeholder', '0:interpolate'), | ||
*connect_data('placeholder', 'shape'), | ||
*connect('shape', '0:gather'), | ||
*connect('indices', '1:gather'), | ||
*connect('axis', '2:gather'), | ||
*connect('gather', '0:mul'), | ||
*connect('multiplier', '1:mul'), | ||
*connect('mul', '1:interpolate'), | ||
*connect('interpolate', 'output'), | ||
], nodes_with_edges_only=True) | ||
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True) | ||
self.assertTrue(flag, resp) | ||
|
||
|
||
class TestInterpolateConcat(unittest.TestCase): | ||
def test_interpolate_concat_reshape_graph_comparison(self): | ||
graph = build_graph(nodes, [ | ||
*connect('placeholder', '0:interpolate'), | ||
*connect('out_shape', '1:interpolate'), | ||
*connect('interpolate', '0:concat'), | ||
*connect('placeholder_1', '1:concat'), | ||
*connect('concat', 'output'), | ||
], nodes_with_edges_only=True) | ||
InterpolateConcat().find_and_replace_pattern(graph) | ||
graph.graph['cmd_params'] = Namespace(keep_shape_ops=True) | ||
graph.clean_up() | ||
graph_ref = build_graph(nodes, [ | ||
*connect('placeholder', '0:interpolate'), | ||
*connect('placeholder_1', 'shape'), | ||
*connect('shape', '0:gather'), | ||
*connect('indices', '1:gather'), | ||
*connect('axis', '2:gather'), | ||
*connect('gather', '1:interpolate'), | ||
*connect('interpolate', '0:concat'), | ||
*connect_data('placeholder_1', '1:concat'), | ||
*connect('concat', 'output'), | ||
], nodes_with_edges_only=True) | ||
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True) | ||
self.assertTrue(flag, resp) |