Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LeakyReluFusion transformation #6816

Merged
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <vector>
#include <memory>

#include <transformations_visibility.hpp>

#include <ngraph/pass/graph_rewrite.hpp>

namespace ngraph {
namespace pass {

// Forward declaration carrying the TRANSFORMATIONS_API export macro; the
// class itself is defined below, outside the namespace block.
class TRANSFORMATIONS_API LeakyReluFusion;

}  // namespace pass
}  // namespace ngraph

/**
 * @ingroup ie_transformation_common_api
 * @brief LeakyReluFusion transformation replaces the following graph:
 * Multiply->Maximum with a single LeakyRelu operation
 */

// MatcherPass that fuses Maximum(x, Multiply(x, alpha)) into PRelu(x, alpha)
// (i.e. LeakyRelu). The pattern and callback are registered in the constructor.
class ngraph::pass::LeakyReluFusion: public ngraph::pass::MatcherPass {
public:
    NGRAPH_RTTI_DECLARATION;
    LeakyReluFusion();
};
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "transformations/common_optimizations/swish_fusion.hpp"
#include "transformations/common_optimizations/normalize_l2_fusion.hpp"
#include "transformations/common_optimizations/pull_transpose_through_fq.hpp"
#include "transformations/common_optimizations/leaky_relu_fusion.hpp"
#include "transformations/common_optimizations/lin_op_sequence_fusion.hpp"
#include "transformations/common_optimizations/remove_filtering_boxes_by_size.hpp"
#include "transformations/common_optimizations/hsigmoid_fusion.hpp"
Expand Down Expand Up @@ -133,6 +134,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
common_fusions->add_matcher<ngraph::pass::DilatedConvolutionConverter>();
common_fusions->add_matcher<ngraph::pass::GeluFusion>();
common_fusions->add_matcher<ngraph::pass::TransposeToReshape>();
common_fusions->add_matcher<ngraph::pass::LeakyReluFusion>();
common_fusions->set_name("ngraph::pass::CommonFusions");

manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution, false>();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/leaky_relu_fusion.hpp"
#include "transformations/utils/utils.hpp"

#include <memory>
#include <vector>

#include <ngraph/opsets/opset8.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include "itt.hpp"


NGRAPH_RTTI_DEFINITION(ngraph::pass::LeakyReluFusion, "LeakyReluFusion", 0);

// Registers a matcher for the subgraph Maximum(data, Multiply(data, alpha))
// and a callback that replaces it with a single PRelu node.
ngraph::pass::LeakyReluFusion::LeakyReluFusion() {
    MATCHER_SCOPE(LeakyReluFusion);
    auto data_pattern = ngraph::pattern::any_input();
    // alpha must have a static shape so shape_size() can be checked below.
    auto alpha_pattern = ngraph::pattern::any_input(pattern::has_static_shape());
    // The Multiply must have a single consumer (the Maximum), otherwise the
    // scaled branch is used elsewhere and removing it would break the graph.
    auto multiply_pattern = ngraph::pattern::wrap_type<opset8::Multiply>({data_pattern, alpha_pattern}, pattern::consumers_count(1));
    auto max_pattern = ngraph::pattern::wrap_type<opset8::Maximum>({data_pattern, multiply_pattern});

    ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
        auto pattern_map = m.get_pattern_value_map();
        auto data = pattern_map.at(data_pattern);
        const auto & original_alpha_pattern = pattern_map.at(alpha_pattern);

        // Only a single-element alpha (scalar or {1}, {1,1}, ...) is a valid
        // LeakyRelu slope; a multi-element alpha would make PRelu apply
        // per-channel slopes, changing semantics vs. the matched Multiply.
        if (shape_size(original_alpha_pattern.get_shape()) != 1)
            return false;

        // NOTE(review): max(x, alpha*x) == PRelu(x, alpha) only holds for
        // alpha <= 1, and alpha is not validated here (it may even be a
        // runtime Parameter, per the unit tests) — confirm whether a
        // constant-value range check is required for soundness.
        auto leaky_relu = register_new_node<ngraph::opset8::PRelu>(data, original_alpha_pattern);
        auto maximum = pattern_map.at(max_pattern);
        // Keep the Maximum's friendly name so downstream tooling can still
        // locate the result tensor by its original name.
        leaky_relu->set_friendly_name(maximum.get_node()->get_friendly_name());

        // Propagate runtime info from both replaced nodes onto the fusion.
        copy_runtime_info({
            pattern_map.at(multiply_pattern).get_node_shared_ptr(),
            maximum.get_node_shared_ptr()
            },
            leaky_relu);
        replace_node(maximum.get_node_shared_ptr(), leaky_relu);

        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(max_pattern, matcher_name);
    this->register_matcher(m, callback);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>

#include <string>
#include <memory>
#include <queue>

#include <ngraph/function.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <transformations/common_optimizations/leaky_relu_fusion.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/constant_folding.hpp>

#include "common_test_utils/ngraph_test_utils.hpp"


using namespace testing;
using namespace ngraph;

// Multiply(data, alpha)->Maximum with a Shape{1} constant alpha must be
// fused into a single PRelu (LeakyRelu) node.
TEST(TransformationTests, LeakyReluFusionConstant) {
    std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
    {
        auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{2, 2});
        auto alpha = opset8::Constant::create(element::f32, Shape{1}, {0.1});
        auto multiply = std::make_shared<opset8::Multiply>(data, alpha);
        auto max = std::make_shared<opset8::Maximum>(data, multiply);
        f = std::make_shared<Function>(NodeVector{max}, ParameterVector{data});

        pass::Manager m;
        m.register_pass<pass::InitNodeInfo>();
        m.register_pass<pass::LeakyReluFusion>();
        m.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        // Reference graph; opset8 used throughout for consistency with the
        // transformed graph above (was opset1::Parameter — same v0 op).
        auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{2, 2});
        auto alpha = opset8::Constant::create(element::f32, Shape{1}, {0.1});
        auto leaky_relu = std::make_shared<opset8::PRelu>(data, alpha);
        f_ref = std::make_shared<Function>(NodeVector{leaky_relu}, ParameterVector{data});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}

// Same as LeakyReluFusionConstant but with a true scalar (Shape{}) alpha:
// the fusion must also accept rank-0 slopes.
TEST(TransformationTests, LeakyReluFusionScalar) {
    std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
    {
        auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{2, 2});
        auto alpha = opset8::Constant::create(element::f32, Shape{}, {0.1});
        auto multiply = std::make_shared<opset8::Multiply>(data, alpha);
        auto max = std::make_shared<opset8::Maximum>(data, multiply);
        f = std::make_shared<Function>(NodeVector{max}, ParameterVector{data});

        pass::Manager m;
        m.register_pass<pass::InitNodeInfo>();
        m.register_pass<pass::LeakyReluFusion>();
        m.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        // Reference graph; opset8 used throughout for consistency with the
        // transformed graph above (was opset1::Parameter — same v0 op).
        auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{2, 2});
        auto alpha = opset8::Constant::create(element::f32, Shape{}, {0.1});
        auto leaky_relu = std::make_shared<opset8::PRelu>(data, alpha);
        f_ref = std::make_shared<Function>(NodeVector{leaky_relu}, ParameterVector{data});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}

// Alpha supplied as a runtime Parameter (not a Constant): the fusion is
// expected to fire as long as alpha's static shape has exactly one element.
TEST(TransformationTests, LeakyReluFusionParameter) {
    std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
    {
        auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{2, 2});
        auto alpha = std::make_shared<opset8::Parameter>(element::f32, Shape{});
        auto multiply = std::make_shared<opset8::Multiply>(data, alpha);
        auto max = std::make_shared<opset8::Maximum>(data, multiply);
        f = std::make_shared<Function>(NodeVector{max}, ParameterVector{data, alpha});

        pass::Manager m;
        m.register_pass<pass::InitNodeInfo>();
        m.register_pass<pass::LeakyReluFusion>();
        m.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        // Reference graph; opset8 used throughout for consistency with the
        // transformed graph above (was opset1::Parameter — same v0 op).
        auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{2, 2});
        auto alpha = std::make_shared<opset8::Parameter>(element::f32, Shape{});
        auto leaky_relu = std::make_shared<opset8::PRelu>(data, alpha);
        f_ref = std::make_shared<Function>(NodeVector{leaky_relu}, ParameterVector{data, alpha});
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}