diff --git a/model-optimizer/automation/package_BOM.txt b/model-optimizer/automation/package_BOM.txt index e4da215ad3cda9..2d19e07fbcb815 100644 --- a/model-optimizer/automation/package_BOM.txt +++ b/model-optimizer/automation/package_BOM.txt @@ -29,6 +29,7 @@ extensions/back/GatherNormalizer.py extensions/back/GroupedConvWeightsNormalize.py extensions/back/I64ToI32.py extensions/back/insert_compatibility_l2normalization.py +extensions/back/InterpolateReshape.py extensions/back/InterpolateToInterpOrResample.py extensions/back/kaldi_remove_memory_output.py extensions/back/LeakyReLUMutation.py diff --git a/model-optimizer/extensions/back/InterpolateReshape.py b/model-optimizer/extensions/back/InterpolateReshape.py new file mode 100644 index 00000000000000..e1ecbebbcd8a9b --- /dev/null +++ b/model-optimizer/extensions/back/InterpolateReshape.py @@ -0,0 +1,154 @@ +""" + Copyright (C) 2018-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import numpy as np + +from extensions.ops.elementwise import Mul +from extensions.ops.gather import Gather +from mo.back.replacement import BackReplacementPattern +from mo.front.caffe.extractors.utils import get_canonical_axis_index +from mo.front.common.partial_infer.utils import int64_array +from mo.front.tf.graph_utils import create_op_with_const_inputs, create_op_node_with_second_input +from mo.graph.graph import Graph +from mo.ops.shape import Shape + + +class InterpolateConcat(BackReplacementPattern): + """ + Replaces hard-coded 1-port input of Interpolate with reshape-able sub-graph using the following Concat inputs + + BEFORE: + input Const + shape=[1, 3, 30, 40] value=[60, 160] + \ / + Interpolate(axes=(2, 3)) input_1 + shape=[1, 3, 60, 160] shape=[1, 4, 60, 160] + \ / + Concat(axis=1) + shape=[1, 7, 60, 160] + AFTER: + input + shape=[1, 3, 30, 40] input_1 + | shape=[1, 4, 60, 160] + | / | + | ShapeOf | + | | | + | Gather | + | indices=(2, 3); axis=0 | + \ | | + Interpolate(axes=(2, 3)) | + shape=[1, 3, 60, 160] | + \ / + Concat(axis=1) + shape=[1, 7, 60, 160] + + """ + enabled = True + graph_condition = [lambda graph: graph.graph['cmd_params'].keep_shape_ops] + force_shape_inference = True + id = 'reshape_interpolate_through_concat' + + @staticmethod + def make_interpolate_reshapeable(interpolate, concat): + assert interpolate.soft_get('type') == 'Interpolate' + assert concat.soft_get('type') == 'Concat' + + output_shape = interpolate.out_port(0).data.get_shape() + + interp_axes = [get_canonical_axis_index(output_shape, axis) for axis in interpolate.axes] + concat_axis = get_canonical_axis_index(output_shape, concat.axis) + if concat_axis in interp_axes: + return + + concat_srcs = [port.get_source() for port in concat.in_ports().values()] + non_interp_concat_srcs = [src for src in concat_srcs if src.node.soft_get('type') != 'Interpolate'] + if len(non_interp_concat_srcs) == 0: + return + + graph = interpolate.graph + src = non_interp_concat_srcs[0] + + shape = Shape(graph, {'name': src.node.soft_get('name', src.node.id) + '/Shape'}).create_node() + shape.in_port(0).connect(src) + gather = create_op_with_const_inputs(graph, Gather, {1: np.array(interpolate.axes, dtype=np.int32), 2: int64_array(0)}, + {'name': shape.name + '/Gathered'}, shape) + interpolate.in_port(1).get_connection().set_source(gather.out_port(0)) + + def find_and_replace_pattern(self, graph: Graph): + for interpolate in graph.get_op_nodes(type='Interpolate'): + if interpolate.in_port(1).get_source().node.soft_get('type') != 'Const': + continue + dsts = interpolate.out_port(0).get_destinations() + if len(dsts) == 1 and dsts[0].node.soft_get('type') == 'Concat': + self.make_interpolate_reshapeable(interpolate, dsts[0].node) + + +class InterpolateReshapeWA(BackReplacementPattern): + """ + Replaces hard-coded 1-port input of Interpolate with reshape-able sub-graph. + WARNING: Could cause troubles if model has hard-coded Interpolate intentionally -- rare situation + + BEFORE: + input Const + shape=[1, 3, 30, 40] value=[60, 160] + \ / + Interpolate(axes=(2, 3)) + shape=[1, 3, 60, 160] + + AFTER: + input + shape=[1, 3, 30, 40] + | \ + | ShapeOf + | | + | Gather Const + | indices=(2, 3); axis=0 value=[2, 4] + | \ / + | Multiply + | / + Interpolate(axes=(2, 3)) + shape=[1, 3, 60, 160] + """ + enabled = False + graph_condition = [lambda graph: graph.graph['cmd_params'].keep_shape_ops] + force_shape_inference = True + id = 'reshape_interpolate_wa' + + def run_after(self): + return [InterpolateConcat] + + @staticmethod + def make_interpolate_reshapeable(interpolate): + assert interpolate.soft_get('type') == 'Interpolate' + axes = interpolate.axes + input_shape = interpolate.in_port(0).data.get_shape() + output_shape = interpolate.out_port(0).data.get_shape() + if not np.all(np.remainder(output_shape, input_shape) == 0) and \ + not np.all(np.remainder(input_shape, output_shape) == 0): + return + graph = interpolate.graph + name = interpolate.soft_get('name', interpolate.id) + shape = Shape(graph, {'name': name + '/ShapeOf'}).create_node() + shape.in_port(0).connect(interpolate.in_port(0).get_source()) + gather = create_op_with_const_inputs(graph, Gather, {1: np.array(axes, dtype=np.int32), 2: int64_array(0)}, + {'name': shape.name + '/Gathered'}, shape) + multipliers = output_shape[axes] / input_shape[axes] + mul = create_op_node_with_second_input(graph, Mul, multipliers, {'name': gather.name + '/Multiplied'}, gather) + interpolate.in_port(1).get_connection().set_source(mul.out_port(0)) + + def find_and_replace_pattern(self, graph: Graph): + for interpolate in graph.get_op_nodes(type='Interpolate'): + if interpolate.in_port(1).get_source().node.soft_get('type') == 'Const': + self.make_interpolate_reshapeable(interpolate) diff --git a/model-optimizer/extensions/back/InterpolateReshape_test.py b/model-optimizer/extensions/back/InterpolateReshape_test.py new file mode 100644 index 00000000000000..f793a4b592fceb --- /dev/null +++ b/model-optimizer/extensions/back/InterpolateReshape_test.py @@ -0,0 +1,97 @@ +""" + Copyright (C) 2018-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import unittest +from argparse import Namespace + +import numpy as np + +from extensions.back.InterpolateReshape import InterpolateReshapeWA, InterpolateConcat +from mo.utils.ir_engine.compare_graphs import compare_graphs +from mo.utils.unittest.graph import build_graph, result, regular_op_with_shaped_data, valued_const_with_data, connect, \ + connect_data + +nodes = { + **regular_op_with_shaped_data('placeholder', [1, 3, 30, 40], {'type': 'Parameter'}), + **valued_const_with_data('out_shape', np.array([60, 160])), + + **regular_op_with_shaped_data('interpolate', [1, 3, 60, 160], {'type': 'Interpolate', 'axes': [2, 3]}), + + **regular_op_with_shaped_data('shape', [4], {'type': 'ShapeOf'}), + **valued_const_with_data('indices', np.array([2, 3])), + **valued_const_with_data('axis', np.array(0)), + **regular_op_with_shaped_data('gather', [2], {'type': 'Gather'}), + + **valued_const_with_data('multiplier', np.array([2, 4])), + **regular_op_with_shaped_data('mul', [2], {'type': 'Multiply'}), + + **regular_op_with_shaped_data('placeholder_1', [1, 3, 60, 160], {'type': 'Parameter'}), + **regular_op_with_shaped_data('concat', [1, 7, 60, 160], {'type': 'Concat', 'axis': 1}), + + **result(), +} + + +class TestInterpolateReshapeWA(unittest.TestCase): + def test_interpolate_reshape_graph_comparison(self): + graph = build_graph(nodes, [ + *connect('placeholder', '0:interpolate'), + *connect('out_shape', '1:interpolate'), + *connect('interpolate', 'output'), + ], nodes_with_edges_only=True) + InterpolateReshapeWA().find_and_replace_pattern(graph) + graph.graph['cmd_params'] = Namespace(keep_shape_ops=True) + graph.clean_up() + graph_ref = build_graph(nodes, [ + *connect('placeholder', '0:interpolate'), + *connect_data('placeholder', 'shape'), + *connect('shape', '0:gather'), + *connect('indices', '1:gather'), + *connect('axis', '2:gather'), + *connect('gather', '0:mul'), + *connect('multiplier', '1:mul'), + *connect('mul', '1:interpolate'), + *connect('interpolate', 'output'), + ], nodes_with_edges_only=True) + (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True) + self.assertTrue(flag, resp) + + +class TestInterpolateConcat(unittest.TestCase): + def test_interpolate_concat_reshape_graph_comparison(self): + graph = build_graph(nodes, [ + *connect('placeholder', '0:interpolate'), + *connect('out_shape', '1:interpolate'), + *connect('interpolate', '0:concat'), + *connect('placeholder_1', '1:concat'), + *connect('concat', 'output'), + ], nodes_with_edges_only=True) + InterpolateConcat().find_and_replace_pattern(graph) + graph.graph['cmd_params'] = Namespace(keep_shape_ops=True) + graph.clean_up() + graph_ref = build_graph(nodes, [ + *connect('placeholder', '0:interpolate'), + *connect('placeholder_1', 'shape'), + *connect('shape', '0:gather'), + *connect('indices', '1:gather'), + *connect('axis', '2:gather'), + *connect('gather', '1:interpolate'), + *connect('interpolate', '0:concat'), + *connect_data('placeholder_1', '1:concat'), + *connect('concat', 'output'), + ], nodes_with_edges_only=True) + (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True) + self.assertTrue(flag, resp)