[TRANSFORMATIONS][GPU] Add GroupNormalization fusion to common optimizations #28387

Open. Wants to merge 45 commits into base: master.

Commits (45):
c199677 Add GroupNormalization fusion to common optimizations (jhajducz, Jan 11, 2025)
364c31b Add GroupNormalization fusion tests (jhajducz, Jan 11, 2025)
309745b Enable GroupNormalization fusion pass in GPU plugin (jhajducz, Jan 11, 2025)
7b979a2 Update copyright notice (jhajducz, Jan 13, 2025)
f03f05c Refactor GroupNormalizationFusion tests to avoid changes in core API (jhajducz, Jan 14, 2025)
27c8e70 Remove GPU plugin specific GroupNormComposition pass (jhajducz, Jan 16, 2025)
46e7e09 Fix RTTI macro in GroupNormalizationFusion header file (jhajducz, Jan 16, 2025)
7fff259 Override TestBody() method in GroupNormalizationFusionTestsFixture (jhajducz, Jan 16, 2025)
5adf805 Explain meaning of GroupNormalizationFusion tests parameters (jhajducz, Jan 16, 2025)
404d459 Require providing correct group norm gamma & beta shapes in positive … (jhajducz, Jan 16, 2025)
a582a50 Use dedicated Constant ctor to create scalar constants in GroupNormal… (jhajducz, Jan 16, 2025)
f7baaf8 Avoid Shape->PartialShape conversion for in/out tensors in GroupNorma… (jhajducz, Jan 16, 2025)
f1c9d73 Use global testing namespace in GroupNormalizationFusion tests (jhajducz, Jan 16, 2025)
5a64d7b Another update of copyright notice (jhajducz, Jan 16, 2025)
2819968 Use const references where possible in GroupNormalizationFusion pass (jhajducz, Jan 16, 2025)
ec9a4db Move GroupNormalizationFusion after MVNFusion pass in GPU plugin tran… (jhajducz, Jan 16, 2025)
09b605f Use OV ptr cast for MVN in GroupNormalizationFusion pass (jhajducz, Jan 16, 2025)
93a63a1 Add 5d and 6d cases to GroupNormalizationFusion tests + fix formatting (jhajducz, Jan 16, 2025)
537f2de Use predicates for type & shape checks that don't depend on other nod… (jhajducz, Jan 16, 2025)
3043c59 Use ov::pass::pattern namespace in GroupNormalizationFusion pass (jhajducz, Jan 17, 2025)
6aca6a3 Remove redundant has_integral_type predicate from GroupNormalizationF… (jhajducz, Jan 17, 2025)
d8ae536 Simplify accessing nodes partial shapes in GroupNormalizationFusion pass (jhajducz, Jan 17, 2025)
2cd179d Fix typo in one of types in GroupNormalizationFusion tests (jhajducz, Jan 20, 2025)
ab5c920 Remove unused include files from GroupNormalizationFusion pass (jhajducz, Feb 4, 2025)
e410b2d Fix handling instance norm gamma & beta in GroupNormalizationFusion pass (jhajducz, Feb 4, 2025)
200f7fc Validate pre-MVN shape and MVN reduction axes in GroupNormalizationFu… (jhajducz, Feb 4, 2025)
325f0f1 Make instance norm gamma & beta explicitly optional in GroupNormaliza… (jhajducz, Feb 4, 2025)
dfc3056 Add GroupNormalizationFusion shared functional subgraph test (jhajducz, Feb 4, 2025)
7da1fc6 Add instance of GroupNormalizationFusion shared functional subgraph t… (jhajducz, Feb 4, 2025)
cafec08 Refactor GroupNormalizationFusion transformation test (jhajducz, Feb 4, 2025)
956fec3 Add missing include file in GroupNormalizationFusion shared functiona… (jhajducz, Feb 4, 2025)
6a9a72a Cosmetic changes in ov::test::SubgraphBaseTest class (jhajducz, Feb 4, 2025)
5a2aa33 Remove redundant virtual keyword in GroupNormalizationFusion shared f… (jhajducz, Feb 4, 2025)
10530f4 Fix accessing type and members variables/functions from GroupNormaliz… (jhajducz, Feb 5, 2025)
67ffd4e Add missing override keywords in GroupNormalizationFusion shared func… (jhajducz, Feb 5, 2025)
aa577f1 Fix usage of ov::element::Type_t and ov::element::Type in GroupNormal… (jhajducz, Feb 5, 2025)
7aac40a Fix comparison of integer expressions of different signedness in Grou… (jhajducz, Feb 5, 2025)
ee9b1d2 Override init_thresholds() in MHA shared functional test class (jhajducz, Feb 5, 2025)
0cb7da4 Simplify pre-MVN shape and MVN reduction axes checks to avoid switch … (jhajducz, Feb 13, 2025)
921e5de Remove comments in self-explanatory parts of code (jhajducz, Feb 13, 2025)
073cf0e Remove unnecessary cast in GroupNormalizationFusion pass (jhajducz, Feb 13, 2025)
46a142a Use lambas for pre-MVN shape and MVN reduction axes checks in GroupNo… (jhajducz, Feb 13, 2025)
1a4edfd Remove comment describing what GroupNormalizationFusion does from GPU… (jhajducz, Feb 13, 2025)
cce734a Remove redundant ov:: namespace prefixes in GroupNormalizationFusion … (jhajducz, Feb 13, 2025)
30278fe Use size_t as data type for number of groups in GroupNormalizationFus… (jhajducz, Feb 13, 2025)
@@ -0,0 +1,32 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

class TRANSFORMATIONS_API GroupNormalizationFusion;

} // namespace pass
} // namespace ov

/**
* @ingroup ov_transformation_common_api
 * @brief GroupNormalizationFusion transformation replaces the
 * following pattern with a fused GroupNormalization op:
 * group_norm_gamma * (instance_norm_gamma * MVN(x) + instance_norm_beta) + group_norm_beta
 * Note that the instance norm related parameters are optional:
 * - instance_norm_gamma is assumed to be filled with ones if not present in the graph
 * - instance_norm_beta is assumed to be filled with zeros if not present in the graph
*/

class ov::pass::GroupNormalizationFusion : public ov::pass::MatcherPass {
public:
OPENVINO_MATCHER_PASS_RTTI("GroupNormalizationFusion");
GroupNormalizationFusion();
};
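
For reference, a minimal sketch (not part of this PR's diff) of how the new pass could be applied through the standard pass manager; run_fusion is a hypothetical helper name:

#include "openvino/core/model.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/common_optimizations/group_normalization_fusion.hpp"

// model is assumed to be an already-built ov::Model that contains
// the decomposed group norm pattern described above
void run_fusion(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::GroupNormalizationFusion>();
    manager.run_passes(model);
}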
@@ -0,0 +1,259 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/group_normalization_fusion.hpp"

#include "itt.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/group_normalization.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/mvn.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/pass/pattern/op/optional.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"

using namespace ov::pass::pattern;

ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
MATCHER_SCOPE(GroupNormalizationFusion);

auto has_real_not_quantized_type = [](const ov::Output<ov::Node>& output) -> bool {
const auto& T = output.get_element_type();
return (T.is_real() && (!T.is_quantized()));
};

auto has_at_least_2d_shape = [](const ov::Output<ov::Node>& output) -> bool {
const auto& output_ps = output.get_partial_shape();
return (output_ps.rank().is_static()) && (output_ps.rank().get_length() >= 2);
};

auto input_m = any_input(all_of({has_real_not_quantized_type, has_at_least_2d_shape, has_static_dim(1)}));

auto pre_mvn_shape_const_m = wrap_type<ov::op::v0::Constant>(all_of({rank_equals(1), has_static_dim(0)}));
auto pre_mvn_reshape_m =
wrap_type<ov::op::v1::Reshape>({input_m, pre_mvn_shape_const_m},
all_of({has_real_not_quantized_type, rank_equals(3), has_static_dim(1)}));

auto mvn_reduction_axes_const_m = wrap_type<ov::op::v0::Constant>(all_of({rank_equals(1), has_static_dim(0)}));
auto mvn_m = wrap_type<ov::op::v6::MVN>({pre_mvn_reshape_m, mvn_reduction_axes_const_m});

auto instance_norm_gamma_m = any_input(all_of({has_real_not_quantized_type, has_static_shape()}));
auto instance_norm_opt_gamma_m = optional<ov::op::v1::Multiply>({mvn_m, instance_norm_gamma_m});

auto instance_norm_beta_m = any_input(all_of({has_real_not_quantized_type, has_static_shape()}));
auto instance_norm_opt_gamma_opt_beta_m =
optional<ov::op::v1::Add>({instance_norm_opt_gamma_m, instance_norm_beta_m});

auto post_instance_norm_shape_m = any_input(all_of({rank_equals(1), has_static_dim(0)}));
auto post_instance_norm_reshape_m =
wrap_type<ov::op::v1::Reshape>({instance_norm_opt_gamma_opt_beta_m, post_instance_norm_shape_m},
all_of({has_real_not_quantized_type, has_at_least_2d_shape, has_static_dim(1)}));

auto group_norm_gamma_m = any_input(all_of({has_real_not_quantized_type, has_static_shape()}));
auto group_norm_gamma_multiply_m =
wrap_type<ov::op::v1::Multiply>({post_instance_norm_reshape_m, group_norm_gamma_m});

auto group_norm_beta_m = any_input(all_of({has_real_not_quantized_type, has_static_shape()}));
auto group_norm_beta_add_m = wrap_type<ov::op::v1::Add>({group_norm_gamma_multiply_m, group_norm_beta_m});

ov::matcher_pass_callback callback = [=](Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();

const auto& input = pattern_map.at(input_m);
const auto& input_ps = input.get_partial_shape();

const auto& T = input.get_element_type();

const auto& pre_mvn_reshape_out_ps = pattern_map.at(pre_mvn_reshape_m).get_partial_shape();

const size_t num_channels = static_cast<size_t>(input_ps[1].get_max_length());
const size_t num_groups = static_cast<size_t>(pre_mvn_reshape_out_ps[1].get_max_length());

// we expect to reshape input in a way that would merge all spatial dimensions
// but leave batch and channel dimensions untouched
const auto& pre_mvn_shape = pattern_map.at(pre_mvn_shape_const_m);
const auto& pre_mvn_shape_const =
ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(pre_mvn_shape_const_m).get_node_shared_ptr());
const auto& pre_mvn_shape_out_ps = pre_mvn_shape.get_shape();
if (pre_mvn_shape_out_ps[0] != 3)
return false;

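        // expected target shape is (batch, num_groups, -1): batch may be given
        // either explicitly or as 0 ("copy dim from input" with special_zero),
        // while -1 merges all remaining dims into one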
auto pre_mvn_shape_vals_correct = [](const std::vector<int64_t>& pre_mvn_shape_vals,
const ov::PartialShape& input_ps,
const ov::Dimension::value_type num_groups) -> bool {
bool res = true;
if (input_ps[0].is_dynamic()) {
if (pre_mvn_shape_vals[0] != 0ll)
res = false;
} else {
if ((pre_mvn_shape_vals[0] != 0ll) &&
(pre_mvn_shape_vals[0] != static_cast<long long>(input_ps[0].get_max_length())))
res = false;
}
if ((pre_mvn_shape_vals[1] != 0ll) && (pre_mvn_shape_vals[1] != static_cast<long long>(num_groups)))
res = false;
if (pre_mvn_shape_vals[2] != -1ll)
res = false;
return res;
};

if (!pre_mvn_shape_vals_correct(pre_mvn_shape_const->cast_vector<int64_t>(), input_ps, num_groups))
return false;

// number of channels has to be divisible by number of groups
if (num_channels % num_groups != 0)
return false;
auto channels_to_groups_ratio = num_channels / num_groups;

// first dimension of MVN input (batch_size) has to be the same
// as in pattern input
if (input_ps[0].get_max_length() != pre_mvn_reshape_out_ps[0].get_max_length())
return false;

// we expect to execute normalization over last dimension of MVN input
const auto& mvn_reduction_axes = pattern_map.at(mvn_reduction_axes_const_m);
const auto& mvn_reduction_axes_const =
ov::as_type_ptr<ov::op::v0::Constant>(mvn_reduction_axes.get_node_shared_ptr());
const auto& mvn_reduction_axes_out_shape = mvn_reduction_axes.get_shape();
if (mvn_reduction_axes_out_shape[0] != 1)
return false;

auto mvn_reduction_axes_correct = [](const std::vector<int64_t>& mvn_reduction_axes) -> bool {
bool res = true;
if ((mvn_reduction_axes[0] != 2ll) && (mvn_reduction_axes[0] != -1ll))
return false;
return res;
};

if (!mvn_reduction_axes_correct(mvn_reduction_axes_const->cast_vector<int64_t>()))
return false;

const auto& post_instance_norm_reshape_out_ps =
pattern_map.at(post_instance_norm_reshape_m).get_partial_shape();
        // post instance norm shape has to be the same as the pattern input shape
if (post_instance_norm_reshape_out_ps != input_ps)
return false;

const auto& group_norm_gamma = pattern_map.at(group_norm_gamma_m);
if (group_norm_gamma.get_element_type() != T)
return false;
if (ov::shape_size(group_norm_gamma.get_shape()) != num_channels)
return false;

const auto& group_norm_beta = pattern_map.at(group_norm_beta_m);
if (group_norm_beta.get_element_type() != T)
return false;
if (ov::shape_size(group_norm_beta.get_shape()) != num_channels)
return false;

auto expected_param_shape = ov::PartialShape({static_cast<ov::Dimension>(num_channels)});

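        // squeeze away any unit dims so that gamma & beta become plain
        // 1D (num_channels) vectors, as required by GroupNormalization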
std::shared_ptr<ov::Node> group_norm_gamma_1d_m = std::make_shared<ov::op::v0::Squeeze>(group_norm_gamma);
const auto& group_norm_gamma_1d_out_ps = group_norm_gamma_1d_m->get_output_partial_shape(0);

if (group_norm_gamma_1d_out_ps != expected_param_shape)
return false;

std::shared_ptr<ov::Node> group_norm_beta_1d_m = std::make_shared<ov::op::v0::Squeeze>(group_norm_beta);
const auto& group_norm_beta_1d_out_ps = group_norm_beta_1d_m->get_output_partial_shape(0);

if (group_norm_beta_1d_out_ps != expected_param_shape)
return false;

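        // instance norm gamma & beta carry one value per group, while GroupNormalization
        // expects one value per channel: repeat each group's value
        // channels_to_groups_ratio times via Gather to broadcast it across its channels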
auto gather_axis_const_m = op::v0::Constant::create(element::i64, Shape{1}, {0});
auto gather_indices_vals = std::vector<int64_t>();
for (auto i = 0ull; i < num_groups; i++)
gather_indices_vals.insert(gather_indices_vals.end(), channels_to_groups_ratio, i);
auto gather_indices_const_m = op::v0::Constant::create(element::i64, Shape{num_channels}, gather_indices_vals);

std::shared_ptr<ov::Node> instance_norm_beta_1d_m = nullptr;
if (pattern_map.count(instance_norm_beta_m) > 0) {
const auto& instance_norm_beta = pattern_map.at(instance_norm_beta_m);
if (instance_norm_beta.get_element_type() != T)
return false;
if (ov::shape_size(instance_norm_beta.get_shape()) != num_groups)
return false;

// ensure that instance_norm_beta will have shape compatible
// with group_norm parameters, i.e. 1D vector of shape (num_channels)
if (ov::shape_size(instance_norm_beta.get_shape()) == 1) {
auto shape_1d_const_m = op::v0::Constant::create(element::i64, Shape{1}, {1});
instance_norm_beta_1d_m =
std::make_shared<ov::op::v1::Reshape>(instance_norm_beta, shape_1d_const_m, true);
} else {
instance_norm_beta_1d_m = std::make_shared<ov::op::v0::Squeeze>(instance_norm_beta);
}

instance_norm_beta_1d_m = std::make_shared<ov::op::v8::Gather>(instance_norm_beta_1d_m,
gather_indices_const_m,
gather_axis_const_m);

const auto& instance_norm_beta_1d_ps = instance_norm_beta_1d_m->get_output_partial_shape(0);
if (instance_norm_beta_1d_ps != expected_param_shape)
return false;

// group_norm_beta = group_norm_gamma * instance_norm_beta + group_norm_beta
auto group_norm_beta_corr_multiply_m =
std::make_shared<ov::op::v1::Multiply>(group_norm_gamma_1d_m, instance_norm_beta_1d_m);
group_norm_beta_1d_m =
std::make_shared<ov::op::v1::Add>(group_norm_beta_corr_multiply_m, group_norm_beta_1d_m);
}

if (pattern_map.count(instance_norm_gamma_m) > 0) {
const auto& instance_norm_gamma = pattern_map.at(instance_norm_gamma_m);
if (instance_norm_gamma.get_element_type() != T)
return false;
if (ov::shape_size(instance_norm_gamma.get_shape()) != num_groups)
return false;

// ensure that instance_norm_gamma will have shape compatible
// with group_norm parameters, i.e. 1D vector of shape (num_channels)
std::shared_ptr<ov::Node> instance_norm_gamma_1d_m = nullptr;
if (ov::shape_size(instance_norm_gamma.get_shape()) == 1) {
auto shape_1d_const_m = op::v0::Constant::create(element::i64, Shape{1}, {1});
instance_norm_gamma_1d_m =
std::make_shared<ov::op::v1::Reshape>(instance_norm_gamma, shape_1d_const_m, true);
} else {
instance_norm_gamma_1d_m = std::make_shared<ov::op::v0::Squeeze>(instance_norm_gamma);
}

instance_norm_gamma_1d_m = std::make_shared<ov::op::v8::Gather>(instance_norm_gamma_1d_m,
gather_indices_const_m,
gather_axis_const_m);
const auto& instance_norm_gamma_1d_ps = instance_norm_gamma_1d_m->get_output_partial_shape(0);
if (instance_norm_gamma_1d_ps != expected_param_shape)
return false;

// group_norm_gamma *= instance_norm_gamma
group_norm_gamma_1d_m =
std::make_shared<ov::op::v1::Multiply>(group_norm_gamma_1d_m, instance_norm_gamma_1d_m);
}

        // we need to cast mvn to the MVN layer type in order to read the actual epsilon value
const auto& mvn_out = pattern_map.at(mvn_m);
const auto& mvn = ov::as_type_ptr<ov::op::v6::MVN>(mvn_out.get_node_shared_ptr());
const auto& epsilon = mvn->get_eps();

// we can finally create GroupNormalization op
std::shared_ptr<ov::Node> group_norm = std::make_shared<ov::op::v12::GroupNormalization>(input,
group_norm_gamma_1d_m,
group_norm_beta_1d_m,
num_groups,
epsilon);

// and do actual graph substitution
group_norm->set_friendly_name(m.get_match_root()->get_friendly_name());
ov::copy_runtime_info(m.get_matched_nodes(), group_norm);
ov::replace_node(m.get_match_root(), group_norm);
return true;
};

auto m = std::make_shared<Matcher>(group_norm_beta_add_m, matcher_name);
this->register_matcher(m, callback);
}
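
To make the matched pattern concrete, here is a hypothetical model (shapes and constant values are illustrative assumptions, not taken from this PR) that the pass would fold into a single GroupNormalization op:

#include "openvino/core/model.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/mvn.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/reshape.hpp"

std::shared_ptr<ov::Model> make_decomposed_group_norm() {
    // (batch=1, channels=4, spatial=8), split into 2 groups of 2 channels each
    auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 4, 8});

    // reshape to (batch, num_groups, -1) so that MVN normalizes each group separately
    auto pre_mvn_shape = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3},
                                                      std::vector<int64_t>{0, 2, -1});
    auto pre_mvn_reshape = std::make_shared<ov::op::v1::Reshape>(input, pre_mvn_shape, true);

    // normalize over the merged (last) dimension
    auto mvn_axes = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1},
                                                 std::vector<int64_t>{-1});
    auto mvn = std::make_shared<ov::op::v6::MVN>(pre_mvn_reshape, mvn_axes, true, 1e-6f,
                                                 ov::op::MVNEpsMode::INSIDE_SQRT);

    // reshape back to the original layout
    auto post_shape = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3},
                                                   std::vector<int64_t>{1, 4, 8});
    auto post_reshape = std::make_shared<ov::op::v1::Reshape>(mvn, post_shape, true);

    // per-channel group norm gamma & beta; the instance norm Multiply/Add
    // are left out here, since the pattern marks them as optional
    auto gamma = ov::op::v0::Constant::create(ov::element::f32, ov::Shape{1, 4, 1},
                                              std::vector<float>(4, 1.0f));
    auto mul = std::make_shared<ov::op::v1::Multiply>(post_reshape, gamma);
    auto beta = ov::op::v0::Constant::create(ov::element::f32, ov::Shape{1, 4, 1},
                                             std::vector<float>(4, 0.0f));
    auto add = std::make_shared<ov::op::v1::Add>(mul, beta);

    return std::make_shared<ov::Model>(ov::OutputVector{add}, ov::ParameterVector{input});
}

Running GroupNormalizationFusion on such a model should replace the whole subgraph with one ov::op::v12::GroupNormalization node with num_groups = 2 and the epsilon read from the MVN node.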