Skip to content

Commit

Permalink
[ MO Interpolate ] Fixing broken model reshape-ability (#619)
Browse files Browse the repository at this point in the history
  • Loading branch information
Evgenya Stepyreva authored May 29, 2020
1 parent 5cc8114 commit e290b14
Show file tree
Hide file tree
Showing 3 changed files with 252 additions and 0 deletions.
1 change: 1 addition & 0 deletions model-optimizer/automation/package_BOM.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ extensions/back/GatherNormalizer.py
extensions/back/GroupedConvWeightsNormalize.py
extensions/back/I64ToI32.py
extensions/back/insert_compatibility_l2normalization.py
extensions/back/InterpolateReshape.py
extensions/back/InterpolateToInterpOrResample.py
extensions/back/kaldi_remove_memory_output.py
extensions/back/LeakyReLUMutation.py
Expand Down
154 changes: 154 additions & 0 deletions model-optimizer/extensions/back/InterpolateReshape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np

from extensions.ops.elementwise import Mul
from extensions.ops.gather import Gather
from mo.back.replacement import BackReplacementPattern
from mo.front.caffe.extractors.utils import get_canonical_axis_index
from mo.front.common.partial_infer.utils import int64_array
from mo.front.tf.graph_utils import create_op_with_const_inputs, create_op_node_with_second_input
from mo.graph.graph import Graph
from mo.ops.shape import Shape


class InterpolateConcat(BackReplacementPattern):
"""
Replaces hard-coded 1-port input of Interpolate with reshape-able sub-graph using the following Concat inputs
BEFORE:
input Const
shape=[1, 3, 30, 40] value=[60, 160]
\ /
Interpolate(axes=(2, 3)) input_1
shape=[1, 3, 60, 160] shape=[1, 4, 60, 160]
\ /
Concat(axis=1)
shape=[1, 7, 60, 160]
AFTER:
input
shape=[1, 3, 30, 40] input_1
| shape=[1, 4, 60, 160]
| / |
| ShapeOf |
| | |
| Gather |
| indices=(2, 3); axis=0 |
\ | |
Interpolate(axes=(2, 3)) |
shape=[1, 3, 60, 160] |
\ /
Concat(axis=1)
shape=[1, 7, 60, 160]
"""
enabled = True
graph_condition = [lambda graph: graph.graph['cmd_params'].keep_shape_ops]
force_shape_inference = True
id = 'reshape_interpolate_through_concat'

@staticmethod
def make_interpolate_reshapeable(interpolate, concat):
assert interpolate.soft_get('type') == 'Interpolate'
assert concat.soft_get('type') == 'Concat'

output_shape = interpolate.out_port(0).data.get_shape()

interp_axes = [get_canonical_axis_index(output_shape, axis) for axis in interpolate.axes]
concat_axis = get_canonical_axis_index(output_shape, concat.axis)
if concat_axis in interp_axes:
return

concat_srcs = [port.get_source() for port in concat.in_ports().values()]
non_interp_concat_srcs = [src for src in concat_srcs if src.node.soft_get('type') != 'Interpolate']
if len(non_interp_concat_srcs) == 0:
return

graph = interpolate.graph
src = non_interp_concat_srcs[0]

shape = Shape(graph, {'name': src.node.soft_get('name', src.node.id) + '/Shape'}).create_node()
shape.in_port(0).connect(src)
gather = create_op_with_const_inputs(graph, Gather, {1: np.array(interpolate.axes, dtype=np.int32), 2: int64_array(0)},
{'name': shape.name + '/Gathered'}, shape)
interpolate.in_port(1).get_connection().set_source(gather.out_port(0))

def find_and_replace_pattern(self, graph: Graph):
for interpolate in graph.get_op_nodes(type='Interpolate'):
if interpolate.in_port(1).get_source().node.soft_get('type') != 'Const':
continue
dsts = interpolate.out_port(0).get_destinations()
if len(dsts) == 1 and dsts[0].node.soft_get('type') == 'Concat':
self.make_interpolate_reshapeable(interpolate, dsts[0].node)


class InterpolateReshapeWA(BackReplacementPattern):
"""
Replaces hard-coded 1-port input of Interpolate with reshape-able sub-graph.
WARNING: Could cause troubles if model has hard-coded Interpolate intentionally -- rare situation
BEFORE:
input Const
shape=[1, 3, 30, 40] value=[60, 160]
\ /
Interpolate(axes=(2, 3))
shape=[1, 3, 60, 160]
AFTER:
input
shape=[1, 3, 30, 40]
| \
| ShapeOf
| |
| Gather Const
| indices=(2, 3); axis=0 value=[2, 4]
| \ /
| Multiply
| /
Interpolate(axes=(2, 3))
shape=[1, 3, 60, 160]
"""
enabled = False
graph_condition = [lambda graph: graph.graph['cmd_params'].keep_shape_ops]
force_shape_inference = True
id = 'reshape_interpolate_wa'

def run_after(self):
return [InterpolateConcat]

@staticmethod
def make_interpolate_reshapeable(interpolate):
assert interpolate.soft_get('type') == 'Interpolate'
axes = interpolate.axes
input_shape = interpolate.in_port(0).data.get_shape()
output_shape = interpolate.out_port(0).data.get_shape()
if not np.all(np.remainder(output_shape, input_shape) == 0) and \
not np.all(np.remainder(input_shape, output_shape) == 0):
return
graph = interpolate.graph
name = interpolate.soft_get('name', interpolate.id)
shape = Shape(graph, {'name': name + '/ShapeOf'}).create_node()
shape.in_port(0).connect(interpolate.in_port(0).get_source())
gather = create_op_with_const_inputs(graph, Gather, {1: np.array(axes, dtype=np.int32), 2: int64_array(0)},
{'name': shape.name + '/Gathered'}, shape)
multipliers = output_shape[axes] / input_shape[axes]
mul = create_op_node_with_second_input(graph, Mul, multipliers, {'name': gather.name + '/Multiplied'}, gather)
interpolate.in_port(1).get_connection().set_source(mul.out_port(0))

def find_and_replace_pattern(self, graph: Graph):
for interpolate in graph.get_op_nodes(type='Interpolate'):
if interpolate.in_port(1).get_source().node.soft_get('type') == 'Const':
self.make_interpolate_reshapeable(interpolate)
97 changes: 97 additions & 0 deletions model-optimizer/extensions/back/InterpolateReshape_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import unittest
from argparse import Namespace

import numpy as np

from extensions.back.InterpolateReshape import InterpolateReshapeWA, InterpolateConcat
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph, result, regular_op_with_shaped_data, valued_const_with_data, connect, \
connect_data

nodes = {
**regular_op_with_shaped_data('placeholder', [1, 3, 30, 40], {'type': 'Parameter'}),
**valued_const_with_data('out_shape', np.array([60, 160])),

**regular_op_with_shaped_data('interpolate', [1, 3, 60, 160], {'type': 'Interpolate', 'axes': [2, 3]}),

**regular_op_with_shaped_data('shape', [4], {'type': 'ShapeOf'}),
**valued_const_with_data('indices', np.array([2, 3])),
**valued_const_with_data('axis', np.array(0)),
**regular_op_with_shaped_data('gather', [2], {'type': 'Gather'}),

**valued_const_with_data('multiplier', np.array([2, 4])),
**regular_op_with_shaped_data('mul', [2], {'type': 'Multiply'}),

**regular_op_with_shaped_data('placeholder_1', [1, 3, 60, 160], {'type': 'Parameter'}),
**regular_op_with_shaped_data('concat', [1, 7, 60, 160], {'type': 'Concat', 'axis': 1}),

**result(),
}


class TestInterpolateReshapeWA(unittest.TestCase):
def test_interpolate_reshape_graph_comparison(self):
graph = build_graph(nodes, [
*connect('placeholder', '0:interpolate'),
*connect('out_shape', '1:interpolate'),
*connect('interpolate', 'output'),
], nodes_with_edges_only=True)
InterpolateReshapeWA().find_and_replace_pattern(graph)
graph.graph['cmd_params'] = Namespace(keep_shape_ops=True)
graph.clean_up()
graph_ref = build_graph(nodes, [
*connect('placeholder', '0:interpolate'),
*connect_data('placeholder', 'shape'),
*connect('shape', '0:gather'),
*connect('indices', '1:gather'),
*connect('axis', '2:gather'),
*connect('gather', '0:mul'),
*connect('multiplier', '1:mul'),
*connect('mul', '1:interpolate'),
*connect('interpolate', 'output'),
], nodes_with_edges_only=True)
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
self.assertTrue(flag, resp)


class TestInterpolateConcat(unittest.TestCase):
def test_interpolate_concat_reshape_graph_comparison(self):
graph = build_graph(nodes, [
*connect('placeholder', '0:interpolate'),
*connect('out_shape', '1:interpolate'),
*connect('interpolate', '0:concat'),
*connect('placeholder_1', '1:concat'),
*connect('concat', 'output'),
], nodes_with_edges_only=True)
InterpolateConcat().find_and_replace_pattern(graph)
graph.graph['cmd_params'] = Namespace(keep_shape_ops=True)
graph.clean_up()
graph_ref = build_graph(nodes, [
*connect('placeholder', '0:interpolate'),
*connect('placeholder_1', 'shape'),
*connect('shape', '0:gather'),
*connect('indices', '1:gather'),
*connect('axis', '2:gather'),
*connect('gather', '1:interpolate'),
*connect('interpolate', '0:concat'),
*connect_data('placeholder_1', '1:concat'),
*connect('concat', 'output'),
], nodes_with_edges_only=True)
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
self.assertTrue(flag, resp)

0 comments on commit e290b14

Please sign in to comment.