diff --git a/inference-engine/src/gna_plugin/gna_plugin.cpp b/inference-engine/src/gna_plugin/gna_plugin.cpp index 786eb6a4471907..c4ef5ae033d1a0 100644 --- a/inference-engine/src/gna_plugin/gna_plugin.cpp +++ b/inference-engine/src/gna_plugin/gna_plugin.cpp @@ -66,6 +66,7 @@ #include "transformations/handle_transposes_around_matmul.hpp" #include "transformations/decompose_2d_conv.hpp" #include "transformations/convert_padded2valid_conv.hpp" +#include "transformations/insert_reshape_around_matmul.hpp" #include "transformations/op_conversions/lstm_cell_decomposition.hpp" #include "transformations/remove_single_input_concat.hpp" @@ -730,10 +731,14 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) { manager.register_pass(); manager.register_pass(); manager.register_pass(); - manager.register_pass(); + manager.register_pass(); + manager.register_pass(); + manager.register_pass(); + manager.register_pass(); manager.register_pass(); manager.register_pass(); manager.register_pass(); + manager.register_pass(); manager.register_pass(); manager.register_pass(); manager.register_pass(); diff --git a/inference-engine/src/gna_plugin/transformations/handle_transposes_around_matmul.cpp b/inference-engine/src/gna_plugin/transformations/handle_transposes_around_matmul.cpp index 9591bd0fc6cef9..e0a009a9e926f1 100644 --- a/inference-engine/src/gna_plugin/transformations/handle_transposes_around_matmul.cpp +++ b/inference-engine/src/gna_plugin/transformations/handle_transposes_around_matmul.cpp @@ -6,31 +6,33 @@ #include -#include -#include -#include +#include #include +#include +#include +#include +#include #include "gna_plugin_log.hpp" #include "backend/gna_limitations.hpp" -using namespace GNAPluginNS; +namespace GNAPluginNS { NGRAPH_RTTI_DEFINITION(HandleTransposesAroundMatMul, "HandleTransposesAroundMatMul", 0); NGRAPH_RTTI_DEFINITION(HandleTransposeBeforeMatMul, "HandleTransposeBeforeMatMul", 0); NGRAPH_RTTI_DEFINITION(HandleTransposeAfterMatMul, "HandleTransposeAfterMatMul", 0); -static void ReplaceTransposeWithReshape(std::shared_ptr transpose_node) { +void ReplaceTransposeWithReshape(std::shared_ptr transpose_node) { auto shape = transpose_node->get_output_shape(0); - auto reshape_const = std::make_shared(ngraph::element::Type_t::i64, + auto reshape_const = std::make_shared(ngraph::element::Type_t::i64, ngraph::Shape{shape.size()}, shape); - auto reshape_node = std::make_shared(transpose_node->input_value(0), reshape_const, false); - reshape_node->set_friendly_name(transpose_node->get_friendly_name() + "/reshape"); + auto reshape_node = std::make_shared(transpose_node->input_value(0), reshape_const, false); + reshape_node->set_friendly_name(transpose_node->get_friendly_name()); ngraph::copy_runtime_info(transpose_node, reshape_node); transpose_node->output(0).replace(reshape_node->output(0)); } -static void InsertTranspose(std::shared_ptr prev_node, const std::string& base_name) { +void InsertTranspose(std::shared_ptr prev_node, const std::string& base_name) { auto consumers = prev_node->output(0).get_target_inputs(); const auto orig_shape = prev_node->get_output_shape(0); std::vector transpose_ids; @@ -44,13 +46,13 @@ static void InsertTranspose(std::shared_ptr prev_node, const std:: std::iota(std::begin(permute_order), std::end(permute_order), 0); std::swap(permute_order[transpose_ids[0]], permute_order[transpose_ids[1]]); - auto transpose_order = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{permute_order.size()}, permute_order); - auto transpose = std::make_shared(prev_node, transpose_order); + auto transpose_order = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{permute_order.size()}, permute_order); + auto transpose = std::make_shared(prev_node, transpose_order); transpose->set_friendly_name(base_name + "/in_transpose"); - auto reshapeConstAfter = std::make_shared(ngraph::element::Type_t::i64, + auto reshapeConstAfter = std::make_shared(ngraph::element::Type_t::i64, ngraph::Shape{orig_shape.size()}, orig_shape); - auto reshapeAfter = std::make_shared(transpose, reshapeConstAfter, false); + auto reshapeAfter = std::make_shared(transpose, reshapeConstAfter, false); reshapeAfter->set_friendly_name(base_name + "/reshape_after_transpose"); ngraph::copy_runtime_info(prev_node, ngraph::NodeVector{transpose, reshapeAfter}); @@ -59,74 +61,102 @@ static void InsertTranspose(std::shared_ptr prev_node, const std:: } } +static bool VerifyReshape(const ngraph::Output& reshape_out) { + auto in_shape = reshape_out.get_node_shared_ptr()->get_input_shape(0); + auto out_shape = reshape_out.get_node_shared_ptr()->get_output_shape(0); + return in_shape[0] != out_shape[0]; +} + HandleTransposeBeforeMatMul::HandleTransposeBeforeMatMul() { - auto reshape = ngraph::pattern::wrap_type({ngraph::pattern::any_input(), - ngraph::pattern::any_input()}, VerifyReshape()); - auto transpose = ngraph::pattern::wrap_type({reshape, + auto constant = ngraph::pattern::wrap_type(); + auto fq = ngraph::pattern::wrap_type({constant, ngraph::pattern::any_input(), + ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()}); + auto reshape = ngraph::pattern::wrap_type({}, VerifyReshape); + auto transpose = ngraph::pattern::wrap_type({reshape, + ngraph::pattern::any_input()}); + auto matmul1 = ngraph::pattern::wrap_type({ + std::make_shared(ngraph::OutputVector{reshape, transpose}), ngraph::pattern::any_input()}); - auto matmul_input = std::make_shared(ngraph::OutputVector{reshape, transpose}); - auto matmul1 = ngraph::pattern::wrap_type({matmul_input, ngraph::pattern::any_input()}); - auto matmul2 = ngraph::pattern::wrap_type({ngraph::pattern::any_input(), matmul_input}); + auto matmul2 = ngraph::pattern::wrap_type({ + std::make_shared(ngraph::OutputVector{constant, fq}), + std::make_shared(ngraph::OutputVector{reshape, transpose, ngraph::pattern::any_input()})}); auto matmul = std::make_shared(ngraph::OutputVector{matmul1, matmul2}); - ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) { - const auto& pattern_map = m.get_pattern_value_map(); - auto transpose_it = pattern_map.find(transpose); - if (transpose_it != std::end(pattern_map)) { - ReplaceTransposeWithReshape(transpose_it->second.get_node_shared_ptr()); - } else { + ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &matcher) { + const auto& pattern_map = matcher.get_pattern_value_map(); + auto matmul_iter = pattern_map.find(matmul1); + if (matmul_iter == std::end(pattern_map) && + (matmul_iter = pattern_map.find(matmul2)) == std::end(pattern_map)) { + return false; + } + + auto transpose_reshape_it = pattern_map.find(transpose); + if (transpose_reshape_it != std::end(pattern_map)) { + ReplaceTransposeWithReshape(transpose_reshape_it->second.get_node_shared_ptr()); + } else if ((transpose_reshape_it = pattern_map.find(reshape)) != std::end(pattern_map)) { auto reshape_node = pattern_map.at(reshape).get_node_shared_ptr(); - if (!GNALimitations::IsTransposeSupported(reshape_node->get_output_shape(0))) return false; - auto matmul_it = pattern_map.find(matmul1); - auto matmul_out = matmul_it != std::end(pattern_map) ? matmul_it->second : pattern_map.at(matmul2); - InsertTranspose(reshape_node, matmul_out.get_node_shared_ptr()->get_friendly_name()); + if (GNALimitations::IsTransposeSupported(reshape_node->get_output_shape(0))) { + auto matmul_node = matmul_iter->second.get_node_shared_ptr(); + InsertTranspose(reshape_node, matmul_node->get_friendly_name()); + } + } + + auto iter = pattern_map.find(fq); + if (iter != pattern_map.end() || + (iter = pattern_map.find(constant)) != pattern_map.end()) { + auto prev_node = iter->second.get_node_shared_ptr(); + if (!GNALimitations::IsTransposeSupported(prev_node->get_output_shape(0))) return false; + auto matmul_node = iter->second.get_node_shared_ptr(); + InsertTranspose(prev_node, matmul_node->get_friendly_name()); } return true; }; - auto m = std::make_shared(matmul, "HandleTransposeBeforeMatMul"); - this->register_matcher(m, callback); + auto matcher = std::make_shared(matmul, "HandleTransposeBeforeMatMul"); + this->register_matcher(matcher, callback); } HandleTransposeAfterMatMul::HandleTransposeAfterMatMul() { - auto matmul = ngraph::pattern::wrap_type(); - auto fq = ngraph::pattern::wrap_type({matmul, ngraph::pattern::any_input(), + auto matmul = ngraph::pattern::wrap_type(); + auto add_left = ngraph::pattern::wrap_type({matmul, ngraph::pattern::any_input()}); + auto add_right = ngraph::pattern::wrap_type({ngraph::pattern::any_input(), matmul}); + auto fq_input = std::make_shared(ngraph::OutputVector{matmul, add_left, add_right}); + auto fq = ngraph::pattern::wrap_type({fq_input, ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()}); - auto transpose_input = std::make_shared(ngraph::OutputVector{matmul, fq}); - auto transpose = ngraph::pattern::wrap_type({transpose_input, ngraph::pattern::any_input()}); + auto transpose_input = std::make_shared(ngraph::OutputVector{fq_input, fq}); + auto transpose = ngraph::pattern::wrap_type({transpose_input, ngraph::pattern::any_input()}); auto reshape_input = std::make_shared(ngraph::OutputVector{transpose_input, transpose}); - auto reshape = ngraph::pattern::wrap_type({reshape_input, - ngraph::pattern::any_input()}, VerifyReshape()); + auto reshape = ngraph::pattern::wrap_type( + {reshape_input, ngraph::pattern::any_input()}, VerifyReshape); - ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) { - const auto& pattern_map = m.get_pattern_value_map(); + ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &matcher) { + const auto& pattern_map = matcher.get_pattern_value_map(); auto transpose_it = pattern_map.find(transpose); if (transpose_it != std::end(pattern_map)) { ReplaceTransposeWithReshape(transpose_it->second.get_node_shared_ptr()); } else { auto reshape_node = pattern_map.at(reshape).get_node_shared_ptr(); - if (!GNALimitations::IsTransposeSupported(reshape_node->get_input_shape(0))) return false; - auto matmul_node = pattern_map.at(matmul).get_node_shared_ptr(); - InsertTranspose(matmul_node, matmul_node->get_friendly_name()); + if (!GNALimitations::IsTransposeSupported(reshape_node->get_output_shape(0))) return false; + auto iter = pattern_map.find(fq); + if (iter == pattern_map.end() && + (iter = pattern_map.find(add_left)) == pattern_map.end() && + (iter = pattern_map.find(add_right)) == pattern_map.end() && + (iter = pattern_map.find(matmul)) == pattern_map.end()) { + return false; + } + auto node = iter->second.get_node_shared_ptr(); + InsertTranspose(node, node->get_friendly_name()); } return true; }; - auto m = std::make_shared(reshape, "HandleTransposeAfterMatMul"); - this->register_matcher(m, callback); -} - -bool VerifyReshape::operator()(const ngraph::Output& reshape_out) const { - auto in_shape = reshape_out.get_node_shared_ptr()->get_input_shape(0); - auto out_shape = reshape_out.get_node_shared_ptr()->get_output_shape(0); - - // Check if Reshape changes the final 2d shape of Affine primitive - in_shape.erase(std::remove(in_shape.begin(), in_shape.end(), 1), in_shape.end()); - out_shape.erase(std::remove(out_shape.begin(), out_shape.end(), 1), out_shape.end()); - return in_shape != out_shape; + auto matcher = std::make_shared(reshape, "HandleTransposeAfterMatMul"); + this->register_matcher(matcher, callback); } HandleTransposesAroundMatMul::HandleTransposesAroundMatMul() { add_matcher(); add_matcher(); } + +} // namespace GNAPluginNS diff --git a/inference-engine/src/gna_plugin/transformations/handle_transposes_around_matmul.hpp b/inference-engine/src/gna_plugin/transformations/handle_transposes_around_matmul.hpp index 2601655f77fe9e..c9e41b641b93f3 100644 --- a/inference-engine/src/gna_plugin/transformations/handle_transposes_around_matmul.hpp +++ b/inference-engine/src/gna_plugin/transformations/handle_transposes_around_matmul.hpp @@ -8,10 +8,6 @@ namespace GNAPluginNS { -struct VerifyReshape { - bool operator()(const ngraph::Output& reshape_out) const; -}; - /** * @brief Inserts Transpose before MatMul or removes it (if it exists) if there is Reshape * before MatMul which changes the batch size: @@ -48,16 +44,16 @@ class HandleTransposeBeforeMatMul : public ngraph::pass::MatcherPass { * | | * [1, A*B] [1, A*B] */ -class HandleTransposeAfterMatMul : public ngraph::pass::MatcherPass { +class HandleTransposeAfterMatMul: public ngraph::pass::MatcherPass { public: - NGRAPH_RTTI_DECLARATION; - HandleTransposeAfterMatMul(); + NGRAPH_RTTI_DECLARATION; + HandleTransposeAfterMatMul(); }; -class HandleTransposesAroundMatMul: public ngraph::pass::GraphRewrite { +class HandleTransposesAroundMatMul : public ngraph::pass::GraphRewrite { public: - NGRAPH_RTTI_DECLARATION; - HandleTransposesAroundMatMul(); + NGRAPH_RTTI_DECLARATION; + HandleTransposesAroundMatMul(); }; } // namespace GNAPluginNS diff --git a/inference-engine/src/gna_plugin/transformations/insert_reshape_around_matmul.cpp b/inference-engine/src/gna_plugin/transformations/insert_reshape_around_matmul.cpp new file mode 100644 index 00000000000000..9d82d7b402d68e --- /dev/null +++ b/inference-engine/src/gna_plugin/transformations/insert_reshape_around_matmul.cpp @@ -0,0 +1,237 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/insert_reshape_around_matmul.hpp" +#include +#include +#include +#include +#include +#include + +#include "gna_plugin_log.hpp" + +namespace GNAPluginNS { + +NGRAPH_RTTI_DEFINITION(InsertReshapeAroundMatmul, "InsertReshapeAroundMatmul", 0); +NGRAPH_RTTI_DEFINITION(InsertReshapeAroundMatmulWithAdd, "InsertReshapeAroundMatmulWithAdd", 0); +NGRAPH_RTTI_DEFINITION(InsertReshapeAroundMatmulWithFq, "InsertReshapeAroundMatmulWithFq", 0); +NGRAPH_RTTI_DEFINITION(InsertReshapeAroundMatmulWithTranspose, "InsertReshapeAroundMatmulWithTranspose", 0); + +static bool InsertReshape( + ngraph::pattern::Matcher &matcher, + const std::shared_ptr& input, + const std::shared_ptr& matmul1, + const std::shared_ptr& matmul2, + const std::shared_ptr& add1 = nullptr, + const std::shared_ptr& add2 = nullptr, + const std::shared_ptr& fake_quantize2 = nullptr, + const std::shared_ptr& transpose = nullptr) { + const auto& pattern_map = matcher.get_pattern_value_map(); + size_t matmul_input_index = 1; + auto iter = pattern_map.find(matmul1); + if (iter == pattern_map.end()) { + iter = pattern_map.find(matmul2); + if ((iter = pattern_map.find(matmul2)) == pattern_map.end()) { + return false; + } + + matmul_input_index = 0; + } + + std::shared_ptr matmul_node = iter->second.get_node_shared_ptr(); + auto matmul_node_shape = matmul_node->get_output_shape(0); + if ((iter = pattern_map.find(input)) == std::end(pattern_map)) { + return false; + } + + std::shared_ptr first_node = iter->second.get_node_shared_ptr(); + auto reshape_input_node = std::dynamic_pointer_cast(first_node); + bool need_reshape_before = !reshape_input_node || reshape_input_node->get_output_shape(0).size() != 2; + if (need_reshape_before) { + auto input_shape = first_node->get_output_shape(0); + std::vector before_shape(2, 1); + std::copy_if(input_shape.begin(), input_shape.end(), before_shape.begin(), [](size_t e) { return e > 1; }); + auto reshape_before_node = std::make_shared(first_node, + std::make_shared(ngraph::element::Type_t::i64, ngraph::Shape{before_shape.size()}, before_shape), false); + reshape_before_node->set_friendly_name(matmul_node->get_friendly_name() + "/reshape_before_matmul"); + ngraph::copy_runtime_info(first_node, reshape_before_node); + matmul_node->input(matmul_input_index).replace_source_output(reshape_before_node->output(0)); + } + + std::shared_ptr last_node; + iter = pattern_map.find(transpose); + if (iter == pattern_map.end() && + (iter = pattern_map.find(fake_quantize2)) == pattern_map.end() && + (iter = pattern_map.find(add1)) == pattern_map.end() && + (iter = pattern_map.find(add2)) == pattern_map.end()) { + last_node = matmul_node; + } else { + last_node = iter->second.get_node_shared_ptr(); + } + + auto consumers = last_node->output(0).get_target_inputs(); + auto last_node_shape = last_node->get_output_shape(0); + bool need_reshape_after = false; + for (auto consumer : consumers) { + auto reshape_output_node = dynamic_cast(consumer.get_node()); + if (!reshape_output_node || reshape_output_node->get_output_shape(0).size() != last_node_shape.size()) { + need_reshape_after = true; + break; + } + } + + if (need_reshape_after) { + auto reshape_after_node = std::make_shared(last_node, + std::make_shared(ngraph::element::Type_t::i64, ngraph::Shape{last_node_shape.size()}, last_node_shape), false); + reshape_after_node->set_friendly_name(last_node->get_friendly_name()); + ngraph::copy_runtime_info(last_node, reshape_after_node); + for (auto consumer : consumers) { + consumer.replace_source_output(reshape_after_node); + } + } + + return need_reshape_before || need_reshape_after; +} + +static std::shared_ptr CreateMatmulPattern( + std::shared_ptr& input, + std::shared_ptr& matmul1, + std::shared_ptr& matmul2, + const ngraph::pattern::op::ValuePredicate& pred = [](const ngraph::Output& output) { return true; }) { + auto constant = ngraph::pattern::wrap_type(); + auto fake_quantize = ngraph::pattern::wrap_type({constant, + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type()}); + auto matmul_input = std::make_shared(ngraph::OutputVector{constant, fake_quantize}); + input = ngraph::pattern::any_input([](const ngraph::Output& node) { + auto shape = node.get_node_shared_ptr()->get_output_shape(0); + return shape.size() > 2 && std::count_if(shape.begin(), shape.end(), [](size_t e) { return e > 1; }) <= 2; }); + matmul1 = ngraph::pattern::wrap_type({matmul_input, input}, pred); + matmul2 = ngraph::pattern::wrap_type({input, matmul_input}, pred); + return std::make_shared(ngraph::OutputVector{matmul1, matmul2}); +} + +InsertReshapeAroundMatmul::InsertReshapeAroundMatmul() { + MATCHER_SCOPE(InsertReshapeAroundMatmul); + + auto pred = [](const ngraph::Output& node) { + const auto& outputs = node.get_node_shared_ptr()->outputs(); + const auto& inputs = outputs[0].get_target_inputs(); + if (inputs.empty()) { + return true; + } + + auto next_node = inputs.begin()->get_node(); + return outputs.size() != 1 || + !dynamic_cast(next_node) && + !dynamic_cast(next_node) && + !dynamic_cast(next_node); + }; + + std::shared_ptr input; + std::shared_ptr matmul1; + std::shared_ptr matmul2; + auto matmul = CreateMatmulPattern(input, matmul1, matmul2, pred); + + ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &matcher) { + return InsertReshape(matcher, input, matmul1, matmul2); + }; + + auto matcher = std::make_shared(matmul, "InsertReshapeAroundMatmul"); + this->register_matcher(matcher, callback); +} + +InsertReshapeAroundMatmulWithAdd::InsertReshapeAroundMatmulWithAdd() { + MATCHER_SCOPE(InsertReshapeAroundMatmulWithAdd); + + auto pred = [](const ngraph::Output& node) { + const auto& outputs = node.get_node_shared_ptr()->outputs(); + const auto& inputs = outputs[0].get_target_inputs(); + if (inputs.empty()) { + return true; + } + + auto next_node = inputs.begin()->get_node(); + return outputs.size() != 1 || + !dynamic_cast(next_node) && + !dynamic_cast(next_node); + }; + + std::shared_ptr input; + std::shared_ptr matmul1; + std::shared_ptr matmul2; + auto matmul = CreateMatmulPattern(input, matmul1, matmul2); + auto add_input = ngraph::pattern::any_input(); + auto add1 = ngraph::pattern::wrap_type({matmul, add_input}, pred); + auto add2 = ngraph::pattern::wrap_type({add_input, matmul}, pred); + auto add = std::make_shared(ngraph::OutputVector{add1, add2}); + + ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &matcher) { + return InsertReshape(matcher, input, matmul1, matmul2, add1, add2); + }; + + auto matcher = std::make_shared(add, "InsertReshapeAroundMatmulWithAdd"); + this->register_matcher(matcher, callback); +} + +InsertReshapeAroundMatmulWithFq::InsertReshapeAroundMatmulWithFq() { + MATCHER_SCOPE(InsertReshapeAroundMatmulWithFq); + + std::shared_ptr input; + std::shared_ptr matmul1; + std::shared_ptr matmul2; + auto matmul = CreateMatmulPattern(input, matmul1, matmul2); + auto add_input = ngraph::pattern::any_input(); + auto add1 = ngraph::pattern::wrap_type({matmul, add_input}); + auto add2 = ngraph::pattern::wrap_type({add_input, matmul}); + auto fq_input = std::make_shared(ngraph::OutputVector{matmul, add1, add2}); + auto fake_quantize2 = ngraph::pattern::wrap_type({fq_input, ngraph::pattern::any_input(), + ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()}, + [](const ngraph::Output& node) { + const auto& outputs = node.get_node_shared_ptr()->outputs(); + const auto& inputs = outputs[0].get_target_inputs(); + if (inputs.empty()) { + return true; + } + + auto next_node = inputs.begin()->get_node(); + return outputs.size() != 1 || + !dynamic_cast(next_node); + }); + + ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &matcher) { + return InsertReshape(matcher, input, matmul1, matmul2, add1, add2, fake_quantize2); + }; + + auto matcher = std::make_shared(fake_quantize2, "InsertReshapeAroundMatmulWithFq"); + this->register_matcher(matcher, callback); +} + +InsertReshapeAroundMatmulWithTranspose::InsertReshapeAroundMatmulWithTranspose() { + MATCHER_SCOPE(InsertReshapeAroundMatmulWithTranspose); + + std::shared_ptr input; + std::shared_ptr matmul1; + std::shared_ptr matmul2; + auto matmul = CreateMatmulPattern(input, matmul1, matmul2); + auto add_input = ngraph::pattern::any_input(); + auto add1 = ngraph::pattern::wrap_type({matmul, add_input}); + auto add2 = ngraph::pattern::wrap_type({add_input, matmul}); + auto fq_input = std::make_shared(ngraph::OutputVector{matmul, add1, add2}); + auto fake_quantize2 = ngraph::pattern::wrap_type({fq_input, ngraph::pattern::any_input(), + ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()}); + auto transpose_input = std::make_shared(ngraph::OutputVector{fq_input, fake_quantize2}); + auto transpose = ngraph::pattern::wrap_type({transpose_input, ngraph::pattern::any_input()}); + + ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &matcher) { + return InsertReshape(matcher, input, matmul1, matmul2, add1, add2, fake_quantize2, transpose); + }; + + auto matcher = std::make_shared(transpose, "InsertReshapeAroundMatmulWithTranspose"); + this->register_matcher(matcher, callback); +} +} // namespace GNAPluginNS diff --git a/inference-engine/src/gna_plugin/transformations/insert_reshape_around_matmul.hpp b/inference-engine/src/gna_plugin/transformations/insert_reshape_around_matmul.hpp new file mode 100644 index 00000000000000..02a728868c4d83 --- /dev/null +++ b/inference-engine/src/gna_plugin/transformations/insert_reshape_around_matmul.hpp @@ -0,0 +1,39 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#ifndef INSERT_RESHAPE_AROUND_MATMUL_HPP +#define INSERT_RESHAPE_AROUND_MATMUL_HPP + +#include + +namespace GNAPluginNS { + +// @brief Insert Reshapes from 3d/4d to 2d before MatMul and from 2d to 3d/4d after MatMul +class InsertReshapeAroundMatmul : public ngraph::pass::MatcherPass { +public: + NGRAPH_RTTI_DECLARATION; + InsertReshapeAroundMatmul(); +}; + +class InsertReshapeAroundMatmulWithAdd : public ngraph::pass::MatcherPass { +public: + NGRAPH_RTTI_DECLARATION; + InsertReshapeAroundMatmulWithAdd(); +}; + +class InsertReshapeAroundMatmulWithFq : public ngraph::pass::MatcherPass { +public: + NGRAPH_RTTI_DECLARATION; + InsertReshapeAroundMatmulWithFq(); +}; + +class InsertReshapeAroundMatmulWithTranspose : public ngraph::pass::MatcherPass { +public: + NGRAPH_RTTI_DECLARATION; + InsertReshapeAroundMatmulWithTranspose(); +}; + +} // namespace GNAPluginNS + +#endif // INSERT_RESHAPE_AROUND_MATMUL_HPP diff --git a/inference-engine/src/gna_plugin/transformations/swap_input_matmul_gna.cpp b/inference-engine/src/gna_plugin/transformations/swap_input_matmul_gna.cpp index 2db8e10620c9dc..8b3dbe391eec60 100644 --- a/inference-engine/src/gna_plugin/transformations/swap_input_matmul_gna.cpp +++ b/inference-engine/src/gna_plugin/transformations/swap_input_matmul_gna.cpp @@ -2,31 +2,34 @@ // SPDX-License-Identifier: Apache-2.0 // +#include #include #include #include #include -#include -#include #include -#include #include -#include +#include +#include +#include +#include #include "gna_plugin_log.hpp" -using namespace GNAPluginNS; +namespace GNAPluginNS { NGRAPH_RTTI_DEFINITION(SwapInputMatMul, "SwapInputMatMul", 0); NGRAPH_RTTI_DEFINITION(SwapInputMatMulWithBias, "SwapInputMatMulWithBias", 0); NGRAPH_RTTI_DEFINITION(SwapInputMatMulWithFq, "SwapInputMatMulWithFq", 0); -static void SwapAndTransposeInputs(std::shared_ptr matmul_node, - std::shared_ptr add, - std::shared_ptr bias, - std::shared_ptr fq) { +static void SwapAndTransposeInputs( + std::shared_ptr matmul_node, + std::shared_ptr add, + std::shared_ptr bias, + std::shared_ptr fq, + const std::string& last_layer_name) { auto create_transpose = [](ngraph::Output node, const std::string& transpose_name) -> std::shared_ptr { ngraph::Shape output_shape = node.get_node_shared_ptr()->get_shape(); @@ -52,15 +55,28 @@ static void SwapAndTransposeInputs(std::shared_ptr matmu std::shared_ptr old_root_node = matmul_node; if (bias != nullptr) { - // output of MatMul will be transposed comparing with original one, so the bias should be transposed too - if (bias->get_output_shape(0).size() > 1) { - bias = create_transpose(bias, bias->get_friendly_name() + "/transpose"); - new_ops.push_back(bias); - } - - new_matmul = std::make_shared(new_matmul, bias); - old_root_node = add; - new_ops.push_back(new_matmul); + // output of MatMul will be transposed comparing with original one, so the bias should be transposed too + if (bias->get_output_shape(0).size() > 1) { + bias = create_transpose(bias, bias->get_friendly_name() + "/transpose"); + new_ops.push_back(bias); + + auto transpose_shape = bias->get_output_shape(0); + auto matmul_shape = matmul_node->get_output_shape(0); + if (transpose_shape.size() > matmul_shape.size()) { + std::vector reshape_shape(matmul_shape.size(), 1); + std::copy_if(transpose_shape.begin(), transpose_shape.end(), reshape_shape.begin(), [](size_t e) { return e > 1; }); + bias = std::make_shared(bias, + std::make_shared(ngraph::element::Type_t::i64, + ngraph::Shape{reshape_shape.size()}, reshape_shape), false); + bias->set_friendly_name(add->get_friendly_name() + "/reshape"); + ngraph::copy_runtime_info(add, bias); + new_ops.push_back(bias); + } + } + + new_matmul = std::make_shared(new_matmul, bias); + old_root_node = add; + new_ops.push_back(new_matmul); } if (fq != nullptr) { @@ -70,113 +86,151 @@ static void SwapAndTransposeInputs(std::shared_ptr matmu new_ops.push_back(new_matmul); } - auto output = create_transpose(new_matmul, matmul_node->get_friendly_name()); + auto output = create_transpose(new_matmul, last_layer_name); new_ops.push_back(output); ngraph::copy_runtime_info(matmul_node, new_ops); ngraph::replace_node(old_root_node, output); } -SwapInputMatMul::SwapInputMatMul() { - MATCHER_SCOPE(SwapInputMatMul); - auto constant = ngraph::pattern::wrap_type({}, [](const ngraph::Output& node) { - auto shape = node.get_node_shared_ptr()->get_output_shape(0); - if (shape.size() != 2 || shape[0] < 8 || ((shape[0] % 8 != 0 || shape[1] % 8 != 0))) { - return false; - } - return true; - }); +static std::shared_ptr CreateMatmul( + bool is_first_constant, + ngraph::pattern::op::ValuePredicate const_predicate, + ngraph::pattern::op::ValuePredicate matmul_predicate = ngraph::pattern::has_static_shape()) { + auto constant = ngraph::pattern::wrap_type({}, const_predicate); auto fake_quantize = ngraph::pattern::wrap_type({constant, ngraph::pattern::wrap_type(), ngraph::pattern::wrap_type(), ngraph::pattern::wrap_type(), ngraph::pattern::wrap_type()}); auto matmul_input = std::make_shared(ngraph::OutputVector{constant, fake_quantize}); - auto matmul = ngraph::pattern::wrap_type({matmul_input, ngraph::pattern::any_input()}, - ngraph::pattern::has_static_shape()); - ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) { + if (is_first_constant) { + return ngraph::pattern::wrap_type( + {matmul_input, ngraph::pattern::any_input()}, matmul_predicate); + } + return ngraph::pattern::wrap_type( + {ngraph::pattern::any_input(), matmul_input}, matmul_predicate); +} + +static std::shared_ptr CreateMatmuls( + std::shared_ptr& matmul1, + std::shared_ptr& matmul2) { + matmul1 = CreateMatmul( + true, + [](const ngraph::Output& node) { return true; }, + [](const ngraph::Output& node) { + auto matmul_node = std::dynamic_pointer_cast(node.get_node_shared_ptr()); + IE_ASSERT(matmul_node != nullptr); + auto input_shape = matmul_node->get_input_shape(0); + return input_shape.size() == 2 && + (!matmul_node->get_transpose_a() && input_shape[0] > 8 || + matmul_node->get_transpose_a() && input_shape[1] > 8); }); + matmul2 = CreateMatmul( + false, + [](const ngraph::Output& node) { return true; }, + [](const ngraph::Output& node) { + auto matmul_node = std::dynamic_pointer_cast(node.get_node_shared_ptr()); + IE_ASSERT(matmul_node != nullptr); + auto first_input_shape = matmul_node->get_input_shape(0); + first_input_shape.erase(std::remove(first_input_shape.begin(), first_input_shape.end(), 1), first_input_shape.end()); + auto second_input_shape = matmul_node->get_input_shape(1); + return node.get_partial_shape().is_static() && + second_input_shape.size() == 2 && + (!matmul_node->get_transpose_b() && second_input_shape[1] <= 8 || + matmul_node->get_transpose_b() && second_input_shape[0] <= 8) && + first_input_shape.size() == 2 && + first_input_shape[0] > 8; }); + return std::make_shared(ngraph::OutputVector{matmul1, matmul2}); +} + +SwapInputMatMul::SwapInputMatMul() { + MATCHER_SCOPE(SwapInputMatMul); + std::shared_ptr matmul1; + std::shared_ptr matmul2; + auto matmul = CreateMatmuls(matmul1, matmul2); + auto callback = [=](ngraph::pattern::Matcher& m) { const auto& pattern_map = m.get_pattern_value_map(); - auto matmul_node = std::dynamic_pointer_cast(pattern_map.at(matmul).get_node_shared_ptr()); + auto iter = pattern_map.find(matmul1); + if (iter == pattern_map.end() && + (iter = pattern_map.find(matmul2)) == pattern_map.end()) { + return false; + } + + auto matmul_node = std::dynamic_pointer_cast(iter->second.get_node_shared_ptr()); IE_ASSERT(matmul_node != nullptr); - SwapAndTransposeInputs(matmul_node, nullptr, nullptr, nullptr); + SwapAndTransposeInputs(matmul_node, nullptr, nullptr, nullptr, ""); return true; }; - auto m = std::make_shared(matmul, matcher_name); - this->register_matcher(m, callback); + auto matcher = std::make_shared(matmul, "SwapInputMatMul"); + this->register_matcher(matcher, callback); } SwapInputMatMulWithBias::SwapInputMatMulWithBias() { MATCHER_SCOPE(SwapInputMatMulWithBias); - auto constant = ngraph::pattern::wrap_type({}, [](const ngraph::Output& node) { - auto shape = node.get_node_shared_ptr()->get_output_shape(0); - if (shape.size() != 2 || shape[0] < 8 || ((shape[0] % 8 != 0 || shape[1] % 8 != 0))) { - return false; - } - return true; - }); - auto fake_quantize = ngraph::pattern::wrap_type({constant, - ngraph::pattern::wrap_type(), - ngraph::pattern::wrap_type(), - ngraph::pattern::wrap_type(), - ngraph::pattern::wrap_type()}); - auto matmul_input = std::make_shared(ngraph::OutputVector{constant, fake_quantize}); - auto matmul = ngraph::pattern::wrap_type({matmul_input, ngraph::pattern::any_input()}, - ngraph::pattern::has_static_shape()); + std::shared_ptr matmul1; + std::shared_ptr matmul2; + auto matmul = CreateMatmuls(matmul1, matmul2); auto bias = ngraph::pattern::wrap_type(); auto add = ngraph::pattern::wrap_type({matmul, bias}); - - ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) { + auto callback = [=](ngraph::pattern::Matcher& m) { const auto& pattern_map = m.get_pattern_value_map(); - auto matmul_node = std::dynamic_pointer_cast(pattern_map.at(matmul).get_node_shared_ptr()); + auto iter = pattern_map.find(matmul1); + if (iter == pattern_map.end() && + (iter = pattern_map.find(matmul2)) == pattern_map.end()) { + return false; + } + + auto matmul_node = std::dynamic_pointer_cast(iter->second.get_node_shared_ptr()); IE_ASSERT(matmul_node != nullptr); - SwapAndTransposeInputs(matmul_node, pattern_map.at(add).get_node_shared_ptr(), - pattern_map.at(bias).get_node_shared_ptr(), nullptr); + SwapAndTransposeInputs( + matmul_node, + pattern_map.at(add).get_node_shared_ptr(), + pattern_map.at(bias).get_node_shared_ptr(), + nullptr, + pattern_map.at(add).get_node_shared_ptr()->get_friendly_name()); return true; }; - auto m = std::make_shared(add, matcher_name); - this->register_matcher(m, callback); + auto matcher = std::make_shared(add, "SwapInputMatMulWithBias"); + this->register_matcher(matcher, callback); } SwapInputMatMulWithFq::SwapInputMatMulWithFq() { MATCHER_SCOPE(SwapInputMatMulWithFq); - auto constant = ngraph::pattern::wrap_type({}, [](const ngraph::Output& node) { - auto shape = node.get_node_shared_ptr()->get_output_shape(0); - if (shape.size() != 2 || shape[0] < 8 || ((shape[0] % 8 != 0 || shape[1] % 8 != 0))) { - return false; - } - return true; - }); - auto fake_quantize = ngraph::pattern::wrap_type({constant, - ngraph::pattern::wrap_type(), - ngraph::pattern::wrap_type(), - ngraph::pattern::wrap_type(), - ngraph::pattern::wrap_type()}); - auto matmul_input = std::make_shared(ngraph::OutputVector{constant, fake_quantize}); - auto matmul = ngraph::pattern::wrap_type({matmul_input, ngraph::pattern::any_input()}, - ngraph::pattern::has_static_shape()); + std::shared_ptr matmul1; + std::shared_ptr matmul2; + auto matmul = CreateMatmuls(matmul1, matmul2); auto bias = ngraph::pattern::wrap_type(); auto add = ngraph::pattern::wrap_type({matmul, bias}); - auto matmul_out = std::make_shared(ngraph::OutputVector{add, matmul}); - auto out_fq = ngraph::pattern::wrap_type({matmul_out, + auto fq_input = std::make_shared(ngraph::OutputVector{add, matmul}); + auto fq = ngraph::pattern::wrap_type({fq_input, ngraph::pattern::wrap_type(), ngraph::pattern::wrap_type(), ngraph::pattern::wrap_type(), ngraph::pattern::wrap_type()}); - - ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) { + auto callback = [=](ngraph::pattern::Matcher& m) { const auto& pattern_map = m.get_pattern_value_map(); - auto matmul_node = std::dynamic_pointer_cast(pattern_map.at(matmul).get_node_shared_ptr()); + auto iter = pattern_map.find(matmul1); + if (iter == pattern_map.end() && + (iter = pattern_map.find(matmul2)) == pattern_map.end()) { + return false; + } + + auto iter_add = pattern_map.find(add); + auto iter_bias = pattern_map.find(bias); + auto matmul_node = std::dynamic_pointer_cast(iter->second.get_node_shared_ptr()); IE_ASSERT(matmul_node != nullptr); - auto add_it = pattern_map.find(add); - auto add_node = (add_it == std::end(pattern_map) ? nullptr : add_it->second.get_node_shared_ptr()); - auto bias_it = pattern_map.find(bias); - auto bias_node = (bias_it == std::end(pattern_map) ? nullptr : bias_it->second.get_node_shared_ptr()); - SwapAndTransposeInputs(matmul_node, add_node, bias_node, pattern_map.at(out_fq).get_node_shared_ptr()); + SwapAndTransposeInputs( + matmul_node, + iter_add != pattern_map.end() ? iter_add->second.get_node_shared_ptr() : nullptr, + iter_bias != pattern_map.end() ? iter_bias->second.get_node_shared_ptr() : nullptr, + pattern_map.at(fq).get_node_shared_ptr(), + pattern_map.at(fq).get_node_shared_ptr()->get_friendly_name()); return true; }; - auto m = std::make_shared(out_fq, matcher_name); - this->register_matcher(m, callback); -} \ No newline at end of file + auto matcher = std::make_shared(fq, "SwapInputMatMulWithFq"); + this->register_matcher(matcher, callback); +} +} // namespace GNAPluginNS diff --git a/inference-engine/src/gna_plugin/transformations/swap_input_matmul_gna.hpp b/inference-engine/src/gna_plugin/transformations/swap_input_matmul_gna.hpp index c9604f8b7c2545..aab88799064ebf 100644 --- a/inference-engine/src/gna_plugin/transformations/swap_input_matmul_gna.hpp +++ b/inference-engine/src/gna_plugin/transformations/swap_input_matmul_gna.hpp @@ -2,15 +2,15 @@ // SPDX-License-Identifier: Apache-2.0 // -#pragma once +#ifndef SWAP_INPUT_MATMUL_GNA_HPP +#define SWAP_INPUT_MATMUL_GNA_HPP -#include -#include #include namespace GNAPluginNS { - -// @brief Swaps and transposes inputs of MatMul if its first input is const and its batch size isn't supported by GNA +// @brief Swaps and transposes inputs of MatMul if +// 1. its first input is const and its batch size isn't supported by GNA +// 2. its first input is non-const and its batch size isn't supported by GNA class SwapInputMatMul: public ngraph::pass::MatcherPass { public: NGRAPH_RTTI_DECLARATION; @@ -28,4 +28,6 @@ class SwapInputMatMulWithFq: public ngraph::pass::MatcherPass { NGRAPH_RTTI_DECLARATION; SwapInputMatMulWithFq(); }; -} // namespace GNAPluginNS \ No newline at end of file +} // namespace GNAPluginNS + +#endif // SWAP_INPUT_MATMUL_GNA_HPP diff --git a/inference-engine/tests/functional/plugin/gna/pass_tests/convert_matmul_to_fullyconnected.cpp b/inference-engine/tests/functional/plugin/gna/pass_tests/convert_matmul_to_fullyconnected.cpp index ddce7bb0dcf189..3efc3160a9d4fc 100644 --- a/inference-engine/tests/functional/plugin/gna/pass_tests/convert_matmul_to_fullyconnected.cpp +++ b/inference-engine/tests/functional/plugin/gna/pass_tests/convert_matmul_to_fullyconnected.cpp @@ -99,7 +99,8 @@ const std::vector>> input_shapes = { {{1, 8}, {8, 1}}, {{128, 8}, {8, 1}}, {{8, 8}, {8, 8}}, - {{1, 16}, {16, 8}} + {{1, 16}, {16, 8}}, + {{6, 16}, {16, 8}} }; @@ -110,4 +111,4 @@ INSTANTIATE_TEST_SUITE_P(smoke_convert_matmul_to_fc, ConvertMatmulToFcPass, ::testing::Values(CommonTestUtils::DEVICE_GNA), ::testing::ValuesIn(configs)), ConvertMatmulToFcPass::getTestCaseName); -} // namespace LayerTestsDefinitions \ No newline at end of file +} // namespace LayerTestsDefinitions diff --git a/inference-engine/tests/unit/gna/ngraph/transformations/gna_insert_reshape_around_matmul.cpp b/inference-engine/tests/unit/gna/ngraph/transformations/gna_insert_reshape_around_matmul.cpp new file mode 100644 index 00000000000000..68422b5a9be3fd --- /dev/null +++ b/inference-engine/tests/unit/gna/ngraph/transformations/gna_insert_reshape_around_matmul.cpp @@ -0,0 +1,190 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "transformations/insert_reshape_around_matmul.hpp" + +#include "common_test_utils/ngraph_test_utils.hpp" +#include +#include +#include +#include +#include + +template +struct InsertReshapeAroundMatmulTest { + static std::shared_ptr CreateAdd(std::shared_ptr input, const ngraph::Shape& constant_shape) { + std::vector data(ngraph::shape_size(constant_shape)); + std::iota(std::begin(data), std::end(data), 1); + auto constant = ngraph::opset8::Constant::create(ngraph::element::i64, constant_shape, data); + return std::make_shared(input, constant); + } + + static std::shared_ptr CreateMatmul( + std::shared_ptr input, + const ngraph::Shape& matmul_constant_shape) { + std::vector data(ngraph::shape_size(matmul_constant_shape)); + std::iota(std::begin(data), std::end(data), 1); + auto constant = ngraph::opset8::Constant::create(ngraph::element::i64, matmul_constant_shape, data); + std::shared_ptr node; + node = std::make_shared(input, constant); + + if (ADD) { + auto matmul_shape = node->get_output_shape(0); + data.resize(ngraph::shape_size(matmul_shape)); + std::iota(std::begin(data), std::end(data), 1); + std::vector constant_add_shape(2, 1); + std::copy_if(matmul_shape.begin(), matmul_shape.end(), constant_add_shape.begin(), [](size_t e) { return e > 1; }); + auto constant_add = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{constant_add_shape}, data); + if (ADD_FIRST_INPUT_NOT_CONSTANT) { + node = std::make_shared(node, constant_add); + } else { + node = std::make_shared(constant_add, node); + } + } + + if (FQ) { + node = std::make_shared( + node, + ngraph::opset8::Constant::create(ngraph::element::f32, {1}, {-0.1}), + ngraph::opset8::Constant::create(ngraph::element::f32, {1}, {0.1}), + ngraph::opset8::Constant::create(ngraph::element::f32, {1}, {-0.1}), + ngraph::opset8::Constant::create(ngraph::element::f32, {1}, {0.1}), + 255); + } + + return node; + } + + static std::shared_ptr CreateFunction( + const ngraph::Shape& input_shape, + const ngraph::Shape& matmul_constant_shape, + const ngraph::Shape& result_shape) { + auto input = std::make_shared(ngraph::element::i64, input_shape); + auto before = std::make_shared(input); + auto matmul = CreateMatmul(before, matmul_constant_shape); + auto after = std::make_shared(matmul); + return std::make_shared( + ngraph::ResultVector{std::make_shared(after)}, + ngraph::ParameterVector{input}); + } + + static std::shared_ptr CreateReferenceFunction( + const ngraph::Shape& input_shape, + const ngraph::Shape& reshape_before_shape, + const ngraph::Shape& matmul_constant_shape, + const ngraph::Shape& reshape_after_shape, + const ngraph::Shape& result_shape) { + auto input = std::make_shared(ngraph::element::i64, input_shape); + auto before = std::make_shared(input); + auto reshape_before_constant = ngraph::opset8::Constant::create(ngraph::element::i64, + ngraph::Shape{reshape_before_shape.size()}, reshape_before_shape); + auto reshape_before = std::make_shared(before, reshape_before_constant, false); + auto matmul = CreateMatmul(reshape_before, matmul_constant_shape); + auto reshape_after_constant = ngraph::opset8::Constant::create(ngraph::element::i64, + ngraph::Shape{reshape_after_shape.size()}, reshape_after_shape); + auto reshape_after = std::make_shared(matmul, reshape_after_constant, false); + auto after = std::make_shared(reshape_after); + return std::make_shared( + ngraph::ResultVector{std::make_shared(after)}, + ngraph::ParameterVector{input}); + } +}; // struct InsertReshapeAroundMatmulTest + +namespace { + +void RunTest(const std::shared_ptr& func, const std::shared_ptr& reference_func) { + { + ngraph::pass::Manager m; + m.register_pass(); + m.register_pass(); + m.register_pass(); + m.register_pass(); + m.register_pass(); + m.run_passes(func); + ASSERT_NO_THROW(check_rt_info(func)); + } + + const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES); + const FunctionsComparator::Result result = func_comparator(func, reference_func); + ASSERT_TRUE(result.valid); +} + +} // namespace + +TEST(TransformationTests, InsertReshapeAroundMatmul) { + RunTest( + InsertReshapeAroundMatmulTest:: + CreateFunction({1, 6, 8}, {8, 10}, {1, 6, 10}), + InsertReshapeAroundMatmulTest:: + CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10})); + RunTest( + InsertReshapeAroundMatmulTest:: + CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}), + InsertReshapeAroundMatmulTest:: + CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10})); + RunTest( + InsertReshapeAroundMatmulTest:: + CreateFunction({1, 6, 1, 8}, {8, 10}, {1, 6, 1, 10}), + InsertReshapeAroundMatmulTest:: + CreateReferenceFunction({1, 6, 1, 8}, {6, 8}, {8, 10}, {1, 6, 1, 10}, {1, 6, 1, 10})); + RunTest( + InsertReshapeAroundMatmulTest:: + CreateReferenceFunction({1, 6, 1, 8}, {6, 8}, {8, 10}, {1, 6, 1, 10}, {1, 6, 1, 10}), + InsertReshapeAroundMatmulTest:: + CreateReferenceFunction({1, 6, 1, 8}, {6, 8}, {8, 10}, {1, 6, 1, 10}, {1, 6, 1, 10})); +} + +TEST(TransformationTests, InsertReshapeAroundMatmulWithAdd) { + RunTest( + InsertReshapeAroundMatmulTest:: + CreateFunction({1, 6, 8}, {8, 10}, {1, 6, 10}), + InsertReshapeAroundMatmulTest:: + CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10})); + RunTest( + InsertReshapeAroundMatmulTest:: + CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}), + InsertReshapeAroundMatmulTest:: + CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10})); +} + +TEST(TransformationTests, InsertReshapeAroundMatmulWithAdd_AddFirstInputConstant) { + RunTest( + InsertReshapeAroundMatmulTest:: + CreateFunction({1, 6, 8}, {8, 10}, {1, 6, 10}), + InsertReshapeAroundMatmulTest:: + CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10})); + RunTest( + InsertReshapeAroundMatmulTest:: + CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}), + InsertReshapeAroundMatmulTest:: + CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10})); +} + +TEST(TransformationTests, InsertReshapeAroundMatmulWithFq) { + RunTest( + InsertReshapeAroundMatmulTest:: + CreateFunction({1, 6, 8}, {8, 10}, {1, 6, 10}), + InsertReshapeAroundMatmulTest:: + CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10})); + RunTest( + InsertReshapeAroundMatmulTest:: + CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}), + InsertReshapeAroundMatmulTest:: + CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10})); +} + +TEST(TransformationTests, InsertReshapeAroundMatmulWithAddAndFq) { + RunTest( + InsertReshapeAroundMatmulTest:: + CreateFunction({1, 6, 8}, {8, 10}, {1, 6, 10}), + InsertReshapeAroundMatmulTest:: + CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10})); + RunTest( + InsertReshapeAroundMatmulTest:: + CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}), + InsertReshapeAroundMatmulTest:: + CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10})); +} diff --git a/inference-engine/tests/unit/gna/ngraph/transformations/gna_swap_input_matmul.cpp b/inference-engine/tests/unit/gna/ngraph/transformations/gna_swap_input_matmul.cpp index 184f0fac937896..40288d05785677 100644 --- a/inference-engine/tests/unit/gna/ngraph/transformations/gna_swap_input_matmul.cpp +++ b/inference-engine/tests/unit/gna/ngraph/transformations/gna_swap_input_matmul.cpp @@ -20,7 +20,8 @@ static std::shared_ptr CreateMatMulFunction(const ngraph::Shap bool withBias, bool withWeightsFq, bool withOutFq, - bool swappedInputs) { + bool swappedInputs, + bool needTranspose) { auto input_params = std::make_shared(ngraph::element::i64, input2_shape); auto constant = ngraph::opset8::Constant::create(ngraph::element::i64, input1_shape, {1}); @@ -33,14 +34,14 @@ static std::shared_ptr CreateMatMulFunction(const ngraph::Shap const_input = std::make_shared(const_input, input_low, input_high, output_low, output_high, 11); } - auto matmul = swappedInputs ? std::make_shared(input_params, const_input, true, true) : - std::make_shared(const_input, input_params); + auto matmul = swappedInputs ? std::make_shared(input_params, const_input, needTranspose, needTranspose) : + std::make_shared(const_input, input_params, needTranspose, needTranspose); std::shared_ptr final_node = matmul; if (withBias) { auto bias = ngraph::opset8::Constant::create(ngraph::element::i64, bias_shape, {1}); std::shared_ptr bias_node = bias; - if (swappedInputs && bias_shape.size() > 1) { + if (needTranspose && bias_shape.size() > 1) { auto transpose_order = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{2}, std::vector{1, 0}); bias_node = std::make_shared(bias_node, transpose_order); @@ -57,7 +58,7 @@ static std::shared_ptr CreateMatMulFunction(const ngraph::Shap output_low, output_high, 11); } - if (swappedInputs) { + if (needTranspose) { auto transpose_order = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{2}, std::vector{1, 0}); final_node = std::make_shared(final_node, transpose_order); @@ -104,6 +105,12 @@ static std::string getTestCaseName(testing::TestParamInfo return result.str(); } +enum class MatmulInputType { + FirstInputConstant, + SecondInputConstant +}; // enum class MatmulInputType + +template class SwapInputMatmul : public CommonTestUtils::TestsCommon, public ::testing::WithParamInterface { public: @@ -112,14 +119,24 @@ class SwapInputMatmul : public CommonTestUtils::TestsCommon, bool withBias, withWeightsFq, withOutFq; std::tie(shapes, withBias, withWeightsFq, withOutFq) = this->GetParam(); - function = CreateMatMulFunction(shapes[0], shapes[1], shapes[2], withBias, withWeightsFq, withOutFq, false); + bool swap_inputs = false; + switch (E) { + case MatmulInputType::FirstInputConstant: + break; + case MatmulInputType::SecondInputConstant: + swap_inputs = true; + break; + } + + function = CreateMatMulFunction(shapes[0], shapes[1], shapes[2], withBias, withWeightsFq, withOutFq, swap_inputs, false); reference_function = CreateMatMulFunction(shapes[0], shapes[1], shapes[2], withBias, withWeightsFq, - withOutFq, true); + withOutFq, !swap_inputs, true); } public: std::shared_ptr function, reference_function; }; +template class SwapInputMatmulNotApplied : public CommonTestUtils::TestsCommon, public ::testing::WithParamInterface { public: @@ -128,42 +145,92 @@ class SwapInputMatmulNotApplied : public CommonTestUtils::TestsCommon, bool withBias, withWeightsFq, withOutFq; std::tie(shapes, withBias, withWeightsFq, withOutFq) = this->GetParam(); - function = CreateMatMulFunction(shapes[0], shapes[1], shapes[2], withBias, withWeightsFq, withOutFq, false); + bool swap_inputs = false; + switch (E) { + case MatmulInputType::FirstInputConstant: + break; + case MatmulInputType::SecondInputConstant: + swap_inputs = true; + break; + } + + function = CreateMatMulFunction(shapes[0], shapes[1], shapes[2], withBias, withWeightsFq, withOutFq, swap_inputs, false); reference_function = ngraph::clone_function(*function); } public: std::shared_ptr function, reference_function; }; -TEST_P(SwapInputMatmul, CompareFunctions) { +using SwapInputMatmulWithFirstInputConstant = SwapInputMatmul; +using SwapInputMatmulWithSecondInputConstant = SwapInputMatmul; +using SwapInputMatmulWithFirstInputConstantNotApplied = SwapInputMatmulNotApplied; +using SwapInputMatmulWithSecondInputConstantNotApplied = SwapInputMatmulNotApplied; + +TEST_P(SwapInputMatmulWithFirstInputConstant, CompareFunctions) { Execute(function, reference_function); } -TEST_P(SwapInputMatmulNotApplied, CompareFunctions) { +TEST_P(SwapInputMatmulWithFirstInputConstantNotApplied, CompareFunctions) { Execute(function, reference_function); } -const std::vector> input_shapes_applied = { +TEST_P(SwapInputMatmulWithSecondInputConstant, CompareFunctions) { + Execute(function, reference_function); +} + +TEST_P(SwapInputMatmulWithSecondInputConstantNotApplied, CompareFunctions) { + Execute(function, reference_function); +} + +const std::vector> input_shapes_for_matmul_with_first_constant_applied = { {{16, 8}, {8, 8}, {16, 8}}, {{16, 8}, {8, 8}, {1}}, }; -const std::vector> input_shapes_not_applied = { +const std::vector> input_shapes_for_matmul_with_first_constant_not_applied = { {{1, 8}, {8, 8}, {1, 8}}, {{8}, {8, 8}, {8}} }; -INSTANTIATE_TEST_SUITE_P(smoke_swap_input_matmul, SwapInputMatmul, +const std::vector> input_shapes_for_matmul_with_second_constant_applied = { + {{64, 6}, {100, 64}, {100, 6}}, + {{64, 6}, {100, 64}, {1}}, +}; + +const std::vector> input_shapes_for_matmul_with_second_constant_not_applied = { + {{64, 16}, {100, 64}, {100, 16}}, + {{64, 6}, {8, 64}, {8, 6}}, + {{8, 1}, {8, 8}, {8, 1}}, + {{8}, {8, 8}, {8}} +}; + +INSTANTIATE_TEST_SUITE_P(smoke_swap_input_matmul, SwapInputMatmulWithFirstInputConstant, + ::testing::Combine( + ::testing::ValuesIn(input_shapes_for_matmul_with_first_constant_applied), + ::testing::ValuesIn(std::vector{false, true}), + ::testing::ValuesIn(std::vector{false, true}), + ::testing::ValuesIn(std::vector{false, true})), + getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_swap_input_matmul, SwapInputMatmulWithFirstInputConstantNotApplied, + ::testing::Combine( + ::testing::ValuesIn(input_shapes_for_matmul_with_first_constant_not_applied), + ::testing::ValuesIn(std::vector{false, true}), + ::testing::ValuesIn(std::vector{false, true}), + ::testing::ValuesIn(std::vector{false, true})), + getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_swap_input_matmul, SwapInputMatmulWithSecondInputConstant, ::testing::Combine( - ::testing::ValuesIn(input_shapes_applied), + ::testing::ValuesIn(input_shapes_for_matmul_with_second_constant_applied), ::testing::ValuesIn(std::vector{false, true}), ::testing::ValuesIn(std::vector{false, true}), ::testing::ValuesIn(std::vector{false, true})), getTestCaseName); -INSTANTIATE_TEST_SUITE_P(smoke_swap_input_matmul, SwapInputMatmulNotApplied, +INSTANTIATE_TEST_SUITE_P(smoke_swap_input_matmul, SwapInputMatmulWithSecondInputConstantNotApplied, ::testing::Combine( - ::testing::ValuesIn(input_shapes_not_applied), + ::testing::ValuesIn(input_shapes_for_matmul_with_second_constant_not_applied), ::testing::ValuesIn(std::vector{false, true}), ::testing::ValuesIn(std::vector{false, true}), ::testing::ValuesIn(std::vector{false, true})), diff --git a/inference-engine/tests/unit/gna/ngraph/transformations/handle_transposes_around_matmul.cpp b/inference-engine/tests/unit/gna/ngraph/transformations/handle_transposes_around_matmul.cpp index df8ac77ed6e85d..cd35294579c09e 100644 --- a/inference-engine/tests/unit/gna/ngraph/transformations/handle_transposes_around_matmul.cpp +++ b/inference-engine/tests/unit/gna/ngraph/transformations/handle_transposes_around_matmul.cpp @@ -70,56 +70,117 @@ std::shared_ptr CreateMatmulFunction(const ngraph::Shape& inpu namespace handle_transpose_after_matmul { -std::shared_ptr CreateMatmulTransposeFunction(const ngraph::Shape& input_shape, - const ngraph::Shape& matmul_shape, const ngraph::Shape& reshape_shape, bool create_reshape_after_transpose) { +std::shared_ptr CreateMatmulTransposeFunction( + const ngraph::Shape& input_shape, + const ngraph::Shape& matmul_shape, + const ngraph::Shape& reshape_shape, + bool create_reshape_after_transpose, + bool enable_last_reshape, + bool enable_add, + bool matmul_on_left_side, + bool enable_fq) { auto input_params = std::make_shared(ngraph::element::i64, input_shape); std::vector data(ngraph::shape_size(matmul_shape)); std::iota(std::begin(data), std::end(data), 1); auto matmul_constant = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_shape, data); - auto matmul = std::make_shared(input_params, matmul_constant); - const auto matmul_output_shape = matmul->get_output_shape(0); + std::shared_ptr node = std::make_shared(input_params, matmul_constant); + const auto matmul_output_shape = node->get_output_shape(0); + if (enable_add) { + auto add_const = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_output_shape, {1}); + if (matmul_on_left_side) { + node = std::make_shared(add_const, node); + } else { + node = std::make_shared(node, add_const); + } + } + + if (enable_fq) { + node = std::make_shared( + node, + ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}), + ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {0.1}), + ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}), + ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {0.1}), + 255); + } auto transpose_order = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {1, 0}); - auto transpose = std::make_shared(matmul, transpose_order); + auto transpose = std::make_shared(node, transpose_order); const auto transpose_output_shape = transpose->get_output_shape(0); - std::shared_ptr reshape; + std::shared_ptr reshape; auto shape_const = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{reshape_shape.size()}, reshape_shape); if (create_reshape_after_transpose) { - const auto matmul_output_shape = matmul->get_output_shape(0); + const auto matmul_output_shape = node->get_output_shape(0); auto reshape_after_transpose_const = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{matmul_output_shape.size()}, matmul_output_shape); auto reshape_after_transpose = std::make_shared(transpose, reshape_after_transpose_const, false); - reshape = std::make_shared(reshape_after_transpose, shape_const, false); + reshape = reshape_after_transpose; + if (enable_last_reshape) { + reshape = std::make_shared(reshape_after_transpose, shape_const, false); + } } else { - reshape = std::make_shared(transpose, shape_const, false); - const auto reshape_output_shape = reshape->get_output_shape(0); + reshape = transpose; + if (enable_last_reshape) { + reshape = std::make_shared(transpose, shape_const, false); + } } auto result = std::make_shared(reshape); return std::make_shared(ngraph::ResultVector{result}, ngraph::ParameterVector{input_params}); } -std::shared_ptr CreateMatmulFunction(const ngraph::Shape& input_shape, - const ngraph::Shape& matmul_shape, const ngraph::Shape& reshape_shape, bool create_reshape_instead_of_transpose) { +std::shared_ptr CreateMatmulFunction( + const ngraph::Shape& input_shape, + const ngraph::Shape& matmul_shape, + const ngraph::Shape& reshape_shape, + bool create_reshape_instead_of_transpose, + bool enable_last_reshape, + bool enable_add, + bool matmul_on_left_side, + bool enable_fq) { auto input_params = std::make_shared(ngraph::element::i64, input_shape); std::vector data(ngraph::shape_size(matmul_shape)); std::iota(std::begin(data), std::end(data), 1); auto matmul_constant = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_shape, data); - auto matmul = std::make_shared(input_params, matmul_constant); + std::shared_ptr node = std::make_shared(input_params, matmul_constant); + const auto matmul_output_shape = node->get_output_shape(0); + if (enable_add) { + auto add_const = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_output_shape, {1}); + if (matmul_on_left_side) { + node = std::make_shared(add_const, node); + } else { + node = std::make_shared(node, add_const); + } + } - std::shared_ptr reshape; + if (enable_fq) { + node = std::make_shared( + node, + ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}), + ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {0.1}), + ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}), + ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {0.1}), + 255); + } + + std::shared_ptr reshape; auto shape_const = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{reshape_shape.size()}, reshape_shape); if (create_reshape_instead_of_transpose) { - const auto matmul_output_shape = matmul->get_output_shape(0); auto reshape_instead_of_transpose_const = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{matmul_output_shape.size()}, {matmul_output_shape[1], matmul_output_shape[0]}); - auto reshape_instead_of_transpose = std::make_shared(matmul, reshape_instead_of_transpose_const, false); - reshape = std::make_shared(reshape_instead_of_transpose, shape_const, false); + auto reshape_instead_of_transpose = std::make_shared(node, reshape_instead_of_transpose_const, false); + reshape = reshape_instead_of_transpose; + if (enable_last_reshape) { + reshape = std::make_shared(reshape_instead_of_transpose, shape_const, false); + } } else { - reshape = std::make_shared(matmul, shape_const, false); + reshape = node; + if (enable_last_reshape) { + reshape = std::make_shared(node, shape_const, false); + } } auto result = std::make_shared(reshape); @@ -153,6 +214,9 @@ TEST(TransformationTests, InsertTransposeBeforeMatmulTest) { RunTest( handle_transpose_before_matmul::CreateMatmulFunction({1, 16}, {8, 2}, {2, 1}, false), handle_transpose_before_matmul::CreateTransposeMatmulFunction({1, 16}, {8, 2}, {2, 1}, true)); + RunTest( + handle_transpose_before_matmul::CreateMatmulFunction({1, 2, 8}, {8, 2}, {2, 1}, false), + handle_transpose_before_matmul::CreateTransposeMatmulFunction({1, 2, 8}, {8, 2}, {2, 1}, true)); } TEST(TransformationTests, InsertTransposeBeforeMatmulTestReshapeInOutEq) { @@ -177,25 +241,59 @@ TEST(TransformationTests, RemoveTransposeBeforeMatmulTestReshapeInOutEq) { } TEST(TransformationTests, InsertTransposeAfterMatmulTest) { - RunTest( - handle_transpose_after_matmul::CreateMatmulFunction({4, 1}, {1, 8}, {2, 16}, false), - handle_transpose_after_matmul::CreateMatmulTransposeFunction({4, 1}, {1, 8}, {2, 16}, true)); + for (auto enable_add : { true, false}) { + for (auto matmul_on_left_side : { true, false}) { + for (auto enable_fq : { true, false}) { + RunTest( + handle_transpose_after_matmul::CreateMatmulFunction( + {4, 1}, {1, 8}, {2, 16}, false, true, enable_add, matmul_on_left_side, enable_fq), + handle_transpose_after_matmul::CreateMatmulTransposeFunction( + {4, 1}, {1, 8}, {2, 16}, true, true, enable_add, matmul_on_left_side, enable_fq)); + } + } + } } TEST(TransformationTests, RemoveTransposeAfterMatmulTest) { - RunTest( - handle_transpose_after_matmul::CreateMatmulTransposeFunction({4, 1}, {1, 8}, {2, 16}, false), - handle_transpose_after_matmul::CreateMatmulFunction({4, 1}, {1, 8}, {2, 16}, true)); + for (auto enable_add : { true, false }) { + for (auto matmul_on_left_side : { true, false }) { + for (auto enable_fq : { true, false }) { + RunTest( + handle_transpose_after_matmul::CreateMatmulTransposeFunction( + {4, 1}, {1, 8}, {2, 16}, false, true, enable_add, matmul_on_left_side, enable_fq), + handle_transpose_after_matmul::CreateMatmulFunction( + {4, 1}, {1, 8}, {2, 16}, true, true, enable_add, matmul_on_left_side, enable_fq)); + } + } + } } TEST(TransformationTests, RemoveTransposeAfterMatmulTestReshapeInOutEq) { - RunTest( - handle_transpose_after_matmul::CreateMatmulTransposeFunction({4, 1}, {1, 8}, {8, 4}, false), - handle_transpose_after_matmul::CreateMatmulTransposeFunction({4, 1}, {1, 8}, {8, 4}, false)); + for (auto enable_add : { true, false }) { + for (auto matmul_on_left_side : { true, false }) { + for (auto enable_fq : { true, false }) { + RunTest( + handle_transpose_after_matmul::CreateMatmulTransposeFunction( + {4, 1}, {1, 8}, {8, 4}, false, true, enable_add, matmul_on_left_side, enable_fq), + handle_transpose_after_matmul::CreateMatmulTransposeFunction( + {4, 1}, {1, 8}, {8, 4}, false, true, enable_add, matmul_on_left_side, enable_fq)); + } + } + } } TEST(TransformationTests, InsertTransposeAfterMatmulTestReshapeInOutEq) { - RunTest( - handle_transpose_after_matmul::CreateMatmulFunction({4, 1}, {1, 8}, {4, 8}, false), - handle_transpose_after_matmul::CreateMatmulFunction({4, 1}, {1, 8}, {4, 8}, false)); + for (auto enable_last_reshape : { true, false }) { + for (auto enable_add : { true, false }) { + for (auto matmul_on_left_side : { true, false }) { + for (auto enable_fq : { true, false }) { + RunTest( + handle_transpose_after_matmul::CreateMatmulFunction( + {4, 1}, {1, 8}, {4, 8}, false, enable_last_reshape, enable_add, matmul_on_left_side, enable_fq), + handle_transpose_after_matmul::CreateMatmulFunction( + {4, 1}, {1, 8}, {4, 8}, false, enable_last_reshape, enable_add, matmul_on_left_side, enable_fq)); + } + } + } + } }