Skip to content

Commit

Permalink
[GNA] Expanding transformations: swap_input_matmul and handle_transpo…
Browse files Browse the repository at this point in the history
…ses_around_matmul (openvinotoolkit#7333)

* Expanding transformations: swap_input_matmul and handle_transposes_around_matmul

* insert_reshape_around_matmul

* fixed failing smoke tests
  • Loading branch information
dmitriikhurtin authored Sep 14, 2021
1 parent 39120a7 commit ba34a19
Show file tree
Hide file tree
Showing 11 changed files with 923 additions and 204 deletions.
7 changes: 6 additions & 1 deletion inference-engine/src/gna_plugin/gna_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
#include "transformations/handle_transposes_around_matmul.hpp"
#include "transformations/decompose_2d_conv.hpp"
#include "transformations/convert_padded2valid_conv.hpp"
#include "transformations/insert_reshape_around_matmul.hpp"
#include "transformations/op_conversions/lstm_cell_decomposition.hpp"
#include "transformations/remove_single_input_concat.hpp"

Expand Down Expand Up @@ -730,10 +731,14 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
manager.register_pass<SplitConvolutionWithFq>();
manager.register_pass<SplitConvolutionWithBias>();
manager.register_pass<SplitConvolution>();
manager.register_pass<HandleTransposesAroundMatMul>();
manager.register_pass<InsertReshapeAroundMatmulWithTranspose>();
manager.register_pass<InsertReshapeAroundMatmulWithFq>();
manager.register_pass<InsertReshapeAroundMatmulWithAdd>();
manager.register_pass<InsertReshapeAroundMatmul>();
manager.register_pass<SwapInputMatMulWithFq>();
manager.register_pass<SwapInputMatMulWithBias>();
manager.register_pass<SwapInputMatMul>();
manager.register_pass<HandleTransposesAroundMatMul>();
manager.register_pass<InsertTransposeAfterConvOrPool>();
manager.register_pass<ReorderActivationAndPooling>();
manager.register_pass<RemoveSingleInputConcat>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,33 @@

#include <numeric>

#include <ngraph/opsets/opset7.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/pattern/op/or.hpp>
#include <openvino/cc/ngraph/itt.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pattern/op/or.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ie/ie_common.h>

#include "gna_plugin_log.hpp"
#include "backend/gna_limitations.hpp"

using namespace GNAPluginNS;
namespace GNAPluginNS {

// RTTI definitions for the three passes implemented in this file; each pairs with the
// NGRAPH_RTTI_DECLARATION inside the corresponding pass class.
// NOTE(review): the trailing 0 is presumably a version number - confirm against the macro definition.
NGRAPH_RTTI_DEFINITION(HandleTransposesAroundMatMul, "HandleTransposesAroundMatMul", 0);
NGRAPH_RTTI_DEFINITION(HandleTransposeBeforeMatMul, "HandleTransposeBeforeMatMul", 0);
NGRAPH_RTTI_DEFINITION(HandleTransposeAfterMatMul, "HandleTransposeAfterMatMul", 0);

/**
 * @brief Replaces a Transpose node with a Reshape that yields the same output shape.
 *
 * The Reshape reads the Transpose's data input, targets the Transpose's output shape,
 * takes over its friendly name and runtime info, and inherits all consumers of output 0.
 *
 * @param transpose_node node to be replaced
 */
void ReplaceTransposeWithReshape(std::shared_ptr<ngraph::Node> transpose_node) {
    auto shape = transpose_node->get_output_shape(0);
    auto reshape_const = std::make_shared<ngraph::opset8::Constant>(ngraph::element::Type_t::i64,
        ngraph::Shape{shape.size()}, shape);
    auto reshape_node = std::make_shared<ngraph::opset8::Reshape>(transpose_node->input_value(0), reshape_const, false);
    // The replacement keeps the name of the node it stands in for.
    reshape_node->set_friendly_name(transpose_node->get_friendly_name());
    ngraph::copy_runtime_info(transpose_node, reshape_node);
    // Redirect every consumer of the Transpose to the new Reshape.
    transpose_node->output(0).replace(reshape_node->output(0));
}

static void InsertTranspose(std::shared_ptr<ngraph::Node> prev_node, const std::string& base_name) {
void InsertTranspose(std::shared_ptr<ngraph::Node> prev_node, const std::string& base_name) {
auto consumers = prev_node->output(0).get_target_inputs();
const auto orig_shape = prev_node->get_output_shape(0);
std::vector<size_t> transpose_ids;
Expand All @@ -44,13 +46,13 @@ static void InsertTranspose(std::shared_ptr<ngraph::Node> prev_node, const std::
std::iota(std::begin(permute_order), std::end(permute_order), 0);
std::swap(permute_order[transpose_ids[0]], permute_order[transpose_ids[1]]);

auto transpose_order = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{permute_order.size()}, permute_order);
auto transpose = std::make_shared<ngraph::opset7::Transpose>(prev_node, transpose_order);
auto transpose_order = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{permute_order.size()}, permute_order);
auto transpose = std::make_shared<ngraph::opset8::Transpose>(prev_node, transpose_order);
transpose->set_friendly_name(base_name + "/in_transpose");

auto reshapeConstAfter = std::make_shared<ngraph::opset7::Constant>(ngraph::element::Type_t::i64,
auto reshapeConstAfter = std::make_shared<ngraph::opset8::Constant>(ngraph::element::Type_t::i64,
ngraph::Shape{orig_shape.size()}, orig_shape);
auto reshapeAfter = std::make_shared<ngraph::opset7::Reshape>(transpose, reshapeConstAfter, false);
auto reshapeAfter = std::make_shared<ngraph::opset8::Reshape>(transpose, reshapeConstAfter, false);
reshapeAfter->set_friendly_name(base_name + "/reshape_after_transpose");
ngraph::copy_runtime_info(prev_node, ngraph::NodeVector{transpose, reshapeAfter});

Expand All @@ -59,74 +61,102 @@ static void InsertTranspose(std::shared_ptr<ngraph::Node> prev_node, const std::
}
}

/**
 * @brief Pattern predicate: reports whether a node changes the first dimension of its data.
 *
 * @param reshape_out output of the node under test
 * @return true when the shape of input 0 and the shape of output 0 differ in their first
 *         dimension; false when they agree or when either shape has no dimensions
 */
static bool VerifyReshape(const ngraph::Output<ngraph::Node>& reshape_out) {
    const auto node = reshape_out.get_node_shared_ptr();
    const auto in_shape = node->get_input_shape(0);
    const auto out_shape = node->get_output_shape(0);
    // A shape without dimensions has no first element; indexing it would be out of bounds.
    if (in_shape.empty() || out_shape.empty()) {
        return false;
    }
    return in_shape[0] != out_shape[0];
}

/**
 * @brief Builds the pattern and callback that adjust the inputs of a MatMul.
 *
 * Two MatMul patterns are accepted:
 *  - matmul1: the first input is a Reshape accepted by VerifyReshape, or a Transpose fed by it;
 *  - matmul2: the first input is a Constant or a FakeQuantize fed by a Constant, and the second
 *    input is that Reshape, that Transpose, or any other node.
 *
 * On a match the callback
 *  - replaces a matched Transpose with a Reshape, or, when only the Reshape matched and
 *    GNALimitations::IsTransposeSupported accepts its output shape, inserts a Transpose after it;
 *  - inserts a Transpose after the matched FakeQuantize/Constant input, if there is one.
 *
 * Nodes inserted by either step are named after the matched MatMul.
 */
HandleTransposeBeforeMatMul::HandleTransposeBeforeMatMul() {
    auto constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
    auto fq = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({constant, ngraph::pattern::any_input(),
        ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()});
    auto reshape = ngraph::pattern::wrap_type<ngraph::opset8::Reshape>({}, VerifyReshape);
    auto transpose = ngraph::pattern::wrap_type<ngraph::opset8::Transpose>({reshape,
        ngraph::pattern::any_input()});
    auto matmul1 = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>({
        std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{reshape, transpose}),
        ngraph::pattern::any_input()});
    auto matmul2 = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>({
        std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{constant, fq}),
        std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{reshape, transpose, ngraph::pattern::any_input()})});
    auto matmul = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul1, matmul2});

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &matcher) {
        const auto& pattern_map = matcher.get_pattern_value_map();
        // Exactly one of the two MatMul alternatives is expected in the map.
        auto matmul_iter = pattern_map.find(matmul1);
        if (matmul_iter == std::end(pattern_map) &&
            (matmul_iter = pattern_map.find(matmul2)) == std::end(pattern_map)) {
            return false;
        }
        const auto matmul_node = matmul_iter->second.get_node_shared_ptr();

        auto transpose_reshape_it = pattern_map.find(transpose);
        if (transpose_reshape_it != std::end(pattern_map)) {
            ReplaceTransposeWithReshape(transpose_reshape_it->second.get_node_shared_ptr());
        } else if ((transpose_reshape_it = pattern_map.find(reshape)) != std::end(pattern_map)) {
            auto reshape_node = pattern_map.at(reshape).get_node_shared_ptr();
            if (GNALimitations::IsTransposeSupported(reshape_node->get_output_shape(0))) {
                InsertTranspose(reshape_node, matmul_node->get_friendly_name());
            }
        }

        // FakeQuantize is looked up first: when it matched, the Transpose goes after it
        // rather than after the Constant that feeds it.
        auto iter = pattern_map.find(fq);
        if (iter != pattern_map.end() ||
            (iter = pattern_map.find(constant)) != pattern_map.end()) {
            auto prev_node = iter->second.get_node_shared_ptr();
            // NOTE(review): this early return reports false even if the Reshape/Transpose step
            // above has already rewritten the graph - confirm that this is intended.
            if (!GNALimitations::IsTransposeSupported(prev_node->get_output_shape(0))) return false;
            // Use the MatMul's name as the base name, as the Reshape branch does; the base name
            // used to be read from the Constant/FakeQuantize entry by mistake.
            InsertTranspose(prev_node, matmul_node->get_friendly_name());
        }
        return true;
    };

    auto matcher = std::make_shared<ngraph::pattern::Matcher>(matmul, "HandleTransposeBeforeMatMul");
    this->register_matcher(matcher, callback);
}

/**
 * @brief Builds the pattern and callback that adjust the output side of a MatMul.
 *
 * The pattern is rooted at a Reshape accepted by VerifyReshape whose data input is, optionally
 * through a Transpose, one of: the MatMul itself, an Add with the MatMul as either operand,
 * or a FakeQuantize fed by one of those.
 *
 * On a match the callback replaces a matched Transpose with a Reshape. When no Transpose
 * matched and GNALimitations::IsTransposeSupported accepts the Reshape's output shape, it
 * inserts a Transpose after the last matched node of the chain (FakeQuantize, then Add,
 * then MatMul).
 */
HandleTransposeAfterMatMul::HandleTransposeAfterMatMul() {
    auto matmul = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>();
    auto add_left = ngraph::pattern::wrap_type<ngraph::opset8::Add>({matmul, ngraph::pattern::any_input()});
    auto add_right = ngraph::pattern::wrap_type<ngraph::opset8::Add>({ngraph::pattern::any_input(), matmul});
    auto fq_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul, add_left, add_right});
    auto fq = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({fq_input, ngraph::pattern::any_input(),
        ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()});
    auto transpose_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{fq_input, fq});
    auto transpose = ngraph::pattern::wrap_type<ngraph::opset8::Transpose>({transpose_input, ngraph::pattern::any_input()});
    auto reshape_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{transpose_input, transpose});
    auto reshape = ngraph::pattern::wrap_type<ngraph::opset8::Reshape>(
        {reshape_input, ngraph::pattern::any_input()}, VerifyReshape);

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &matcher) {
        const auto& pattern_map = matcher.get_pattern_value_map();
        auto transpose_it = pattern_map.find(transpose);
        if (transpose_it != std::end(pattern_map)) {
            ReplaceTransposeWithReshape(transpose_it->second.get_node_shared_ptr());
        } else {
            auto reshape_node = pattern_map.at(reshape).get_node_shared_ptr();
            // NOTE(review): the check uses the Reshape's output shape, while the Transpose is
            // inserted on the node feeding the Reshape - confirm the output shape is intended.
            if (!GNALimitations::IsTransposeSupported(reshape_node->get_output_shape(0))) return false;
            // Pick the last matched node before the Reshape: FakeQuantize, Add, or MatMul.
            auto iter = pattern_map.find(fq);
            if (iter == pattern_map.end() &&
                (iter = pattern_map.find(add_left)) == pattern_map.end() &&
                (iter = pattern_map.find(add_right)) == pattern_map.end() &&
                (iter = pattern_map.find(matmul)) == pattern_map.end()) {
                return false;
            }
            auto node = iter->second.get_node_shared_ptr();
            InsertTranspose(node, node->get_friendly_name());
        }
        return true;
    };

    auto matcher = std::make_shared<ngraph::pattern::Matcher>(reshape, "HandleTransposeAfterMatMul");
    this->register_matcher(matcher, callback);
}

/**
 * @brief Assembles the combined rewrite from its two matcher passes:
 *        HandleTransposeBeforeMatMul is added first, HandleTransposeAfterMatMul second.
 */
HandleTransposesAroundMatMul::HandleTransposesAroundMatMul() {
    this->add_matcher<HandleTransposeBeforeMatMul>();
    this->add_matcher<HandleTransposeAfterMatMul>();
}

} // namespace GNAPluginNS
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@

namespace GNAPluginNS {

/**
 * @brief Predicate functor over a node output; only the call operator's declaration lives here.
 * NOTE(review): presumably used as the Reshape predicate of the MatMul patterns - confirm at call sites.
 */
struct VerifyReshape {
bool operator()(const ngraph::Output<ngraph::Node>& reshape_out) const;
};

/**
* @brief Inserts Transpose before MatMul or removes it (if it exists) if there is Reshape
* before MatMul which changes the batch size:
Expand Down Expand Up @@ -48,16 +44,16 @@ class HandleTransposeBeforeMatMul : public ngraph::pass::MatcherPass {
* | |
* [1, A*B] [1, A*B]
*/
class HandleTransposeAfterMatMul : public ngraph::pass::MatcherPass {
class HandleTransposeAfterMatMul: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
HandleTransposeAfterMatMul();
NGRAPH_RTTI_DECLARATION;
HandleTransposeAfterMatMul();
};

class HandleTransposesAroundMatMul: public ngraph::pass::GraphRewrite {
class HandleTransposesAroundMatMul : public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
HandleTransposesAroundMatMul();
NGRAPH_RTTI_DECLARATION;
HandleTransposesAroundMatMul();
};

} // namespace GNAPluginNS
Loading

0 comments on commit ba34a19

Please sign in to comment.