From 632dfa0d59a305c319d081d0d20d90bdb5295a8a Mon Sep 17 00:00:00 2001
From: Jedrzej Hajduczenia
Date: Sat, 11 Jan 2025 01:26:33 +0000
Subject: [PATCH 01/45] Add GroupNormalization fusion to common optimizations

---
 .../group_normalization_fusion.hpp            |  32 ++
 .../group_normalization_fusion.cpp            | 288 ++++++++++++++++++
 2 files changed, 320 insertions(+)
 create mode 100644 src/common/transformations/include/transformations/common_optimizations/group_normalization_fusion.hpp
 create mode 100644 src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp

diff --git a/src/common/transformations/include/transformations/common_optimizations/group_normalization_fusion.hpp b/src/common/transformations/include/transformations/common_optimizations/group_normalization_fusion.hpp
new file mode 100644
index 00000000000000..d7ad56946295c7
--- /dev/null
+++ b/src/common/transformations/include/transformations/common_optimizations/group_normalization_fusion.hpp
@@ -0,0 +1,32 @@
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "openvino/pass/graph_rewrite.hpp"
+#include "transformations_visibility.hpp"
+
+namespace ov {
+namespace pass {
+
+class TRANSFORMATIONS_API GroupNormalizationFusion;
+
+}  // namespace pass
+}  // namespace ov
+
+/**
+ * @ingroup ov_transformation_common_api
+ * @brief GroupNormalizationFusion transformation replaces the
+ * following pattern with a fused GroupNormalization op:
+ * group_norm_gamma * (instance_norm_gamma * MVN(x) + instance_norm_beta) + group_norm_beta
+ * Note that the instance norm related parameters are optional:
+ * - instance_norm_gamma is assumed to be filled with ones if not present in the graph
+ * - instance_norm_beta is assumed to be filled with zeros if not present in the graph
+ */
+
+class ov::pass::GroupNormalizationFusion : public ov::pass::MatcherPass {
+public:
+    OPENVINO_RTTI("GroupNormalizationFusion", "0");
+    GroupNormalizationFusion();
+};

diff --git a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp
new file mode 100644
index 00000000000000..50a04c459b989e
--- /dev/null
+++ b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp
@@ -0,0 +1,288 @@
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "transformations/common_optimizations/group_normalization_fusion.hpp"
+
+#include "itt.hpp"
+#include "openvino/core/rt_info.hpp"
+#include "openvino/op/add.hpp"
+#include "openvino/op/concat.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/convert.hpp"
+#include "openvino/op/divide.hpp"
+#include "openvino/op/group_normalization.hpp"
+#include "openvino/op/multiply.hpp"
+#include "openvino/op/mvn.hpp"
+#include "openvino/op/reshape.hpp"
+#include "openvino/op/shape_of.hpp"
+#include "openvino/op/squeeze.hpp"
+#include "openvino/op/subtract.hpp"
+#include "openvino/pass/pattern/op/or.hpp"
+#include "openvino/pass/pattern/op/wrap_type.hpp"
+#include "transformations/utils/utils.hpp"
+
+ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
+    MATCHER_SCOPE(GroupNormalizationFusion);
+
+    auto input_m = ov::pass::pattern::any_input();
+
+    auto pre_mvn_shape_const_m = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
+    auto pre_mvn_reshape_m =
+        ov::pass::pattern::wrap_type<ov::op::v1::Reshape>({input_m, pre_mvn_shape_const_m});
+
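+    // MVN in the matched pattern operates on a (batch_size, num_groups, -1)
+    // view of the input, created by the pre-MVN Reshape matched above
+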
+    auto axes_const_m = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
+    auto mvn_m = ov::pass::pattern::wrap_type<ov::op::v6::MVN>({pre_mvn_reshape_m, axes_const_m});
+
+    auto instance_norm_gamma_m = ov::pass::pattern::any_input();
+    auto instance_norm_gamma_multiply_m =
+        ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({mvn_m, instance_norm_gamma_m});
+    auto instance_norm_opt_gamma_m =
+        std::make_shared<ov::pass::pattern::op::Or>(ov::OutputVector{mvn_m, instance_norm_gamma_multiply_m});
+
+    auto instance_norm_beta_m = ov::pass::pattern::any_input();
+    auto instance_norm_beta_add_m =
+        ov::pass::pattern::wrap_type<ov::op::v1::Add>({instance_norm_opt_gamma_m, instance_norm_beta_m});
+    auto instance_norm_opt_gamma_opt_beta_m = std::make_shared<ov::pass::pattern::op::Or>(
+        ov::OutputVector{instance_norm_opt_gamma_m, instance_norm_beta_add_m});
+
+    auto post_instance_norm_shape_m = ov::pass::pattern::any_input();
+    auto post_instance_norm_reshape_m = ov::pass::pattern::wrap_type<ov::op::v1::Reshape>(
+        {instance_norm_opt_gamma_opt_beta_m, post_instance_norm_shape_m});
+
+    auto group_norm_gamma_m = ov::pass::pattern::any_input();
+    auto group_norm_gamma_multiply_m =
+        ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({post_instance_norm_reshape_m, group_norm_gamma_m});
+
+    auto group_norm_beta_m = ov::pass::pattern::any_input();
+    auto group_norm_beta_add_m =
+        ov::pass::pattern::wrap_type<ov::op::v1::Add>({group_norm_gamma_multiply_m, group_norm_beta_m});
+
+    ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
+        const auto& pattern_map = m.get_pattern_value_map();
+
+        auto input = pattern_map.at(input_m);
+        auto input_ps = input.get_partial_shape();
+
+        auto T = input.get_element_type();
+
+        // this pattern supports only real and not quantized data types
+        if ((!T.is_real()) || (T.is_quantized()))
+            return false;
+
+        // expecting at least 2D tensor as pattern input:
+        // (batch_size, num_channels, ...)
+        if (input_ps.size() < 2)
+            return false;
+        // channel dimension has to be static, all other dimensions in input can be dynamic
+        if (input_ps[1].is_dynamic())
+            return false;
+
+        auto pre_mvn_reshape_out = pattern_map.at(pre_mvn_reshape_m);
+        auto pre_mvn_reshape_out_ps = pre_mvn_reshape_out.get_partial_shape();
+
+        // expecting 3D tensor as pre-MVN reshape output:
+        // (batch_size, num_groups, -1)
+        if (pre_mvn_reshape_out_ps.size() != 3)
+            return false;
+
+        // num_groups dimension has to be static, all other dimensions in pre-MVN reshape output can be dynamic
+        if (pre_mvn_reshape_out_ps[1].is_dynamic())
+            return false;
+
+        auto num_channels = input_ps[1].get_max_length();
+        auto num_groups = pre_mvn_reshape_out_ps[1].get_max_length();
+
+        // number of channels has to be divisible by number of groups
+        if (num_channels % num_groups != 0)
+            return false;
+        auto channels_to_groups_ratio = num_channels / num_groups;
+
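+        // channels_to_groups_ratio tells how many times the per-group instance
+        // norm parameters have to be repeated to obtain the per-channel
+        // gamma/beta inputs expected by GroupNormalization
+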
+        // MVN input has to have at least two dimensions:
+        // (batch_size, num_groups, ...)
+        if (pre_mvn_reshape_out_ps.size() < 2)
+            return false;
+
+        // first dimension of MVN input (batch_size) has to be the same
+        // as in pattern input
+        if (input_ps[0].get_max_length() != pre_mvn_reshape_out_ps[0].get_max_length())
+            return false;
+
+        auto post_instance_norm_reshape_out = pattern_map.at(post_instance_norm_reshape_m);
+        auto post_instance_norm_reshape_out_ps = post_instance_norm_reshape_out.get_partial_shape();
+
+        // post instance norm shape has to be same as in pattern input:
+        // (batch_size, num_channels, ...)
+        if (post_instance_norm_reshape_out_ps != input_ps)
+            return false;
+
+        auto group_norm_gamma = pattern_map.at(group_norm_gamma_m);
+        auto group_norm_gamma_ps = group_norm_gamma.get_partial_shape();
+
+        // group_norm_gamma has to share the same data type as
+        // pattern input
+        if (group_norm_gamma.get_element_type() != T)
+            return false;
+
+        // group_norm_gamma has to be static
+        if (group_norm_gamma_ps.is_dynamic())
+            return false;
+
+        // number of elements in group_norm_gamma must be equal to
+        // number of channels
+        if (ov::shape_size(group_norm_gamma.get_shape()) != num_channels)
+            return false;
+
+        auto group_norm_beta = pattern_map.at(group_norm_beta_m);
+        auto group_norm_beta_ps = group_norm_beta.get_partial_shape();
+
+        // group_norm_beta has to share the same data type as
+        // pattern input
+        if (group_norm_beta.get_element_type() != T)
+            return false;
+
+        // group_norm_beta has to be static
+        if (group_norm_beta_ps.is_dynamic())
+            return false;
+
+        // number of elements in group_norm_beta must be equal to
+        // number of channels
+        if (ov::shape_size(group_norm_beta.get_shape()) != num_channels)
+            return false;
+
+        auto expected_param_shape = ov::PartialShape({num_channels});
+
+        std::shared_ptr<ov::Node> group_norm_gamma_1d_m = std::make_shared<ov::op::v0::Squeeze>(group_norm_gamma);
+        auto group_norm_gamma_1d_out = group_norm_gamma_1d_m->get_default_output();
+        auto group_norm_gamma_1d_out_ps = group_norm_gamma_1d_out.get_partial_shape();
+
+        if (group_norm_gamma_1d_out_ps != expected_param_shape)
+            return false;
+
+        std::shared_ptr<ov::Node> group_norm_beta_1d_m = std::make_shared<ov::op::v0::Squeeze>(group_norm_beta);
+        auto group_norm_beta_1d_out = group_norm_beta_1d_m->get_default_output();
+        auto group_norm_beta_1d_out_ps = group_norm_beta_1d_out.get_partial_shape();
+
+        if (group_norm_beta_1d_out_ps != expected_param_shape)
+            return false;
+
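+        // instance norm gamma and beta are optional Or-branches of the pattern,
+        // so each of the blocks below runs only when the corresponding input
+        // was actually matched
+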
+        std::shared_ptr<ov::Node> instance_norm_beta_1d_m = nullptr;
+        if (pattern_map.count(instance_norm_beta_m) > 0) {
+            auto instance_norm_beta = pattern_map.at(instance_norm_beta_m);
+            auto instance_norm_beta_ps = instance_norm_beta.get_partial_shape();
+
+            // instance_norm_beta has to share the same data type as
+            // pattern input
+            if (instance_norm_beta.get_element_type() != T)
+                return false;
+
+            // instance_norm_beta has to be static
+            if (instance_norm_beta_ps.is_dynamic())
+                return false;
+
+            // number of elements in instance_norm_beta must be equal to
+            // number of groups
+            if (ov::shape_size(instance_norm_beta.get_shape()) != num_groups)
+                return false;
+
+            // ensure that instance_norm_beta will have shape compatible
+            // with group_norm parameters, i.e. 1D vector of shape (num_channels)
+            if (ov::shape_size(instance_norm_beta.get_shape()) == 1) {
+                auto shape_1d_const_m = op::v0::Constant::create(element::i64, Shape{1}, {1});
+                instance_norm_beta_1d_m =
+                    std::make_shared<ov::op::v1::Reshape>(instance_norm_beta, shape_1d_const_m, true);
+            } else {
+                instance_norm_beta_1d_m = std::make_shared<ov::op::v0::Squeeze>(instance_norm_beta);
+            }
+            ov::OutputVector instance_norm_beta_concat_inputs;
+            for (auto i = 0; i < channels_to_groups_ratio; i++)
+                instance_norm_beta_concat_inputs.push_back(instance_norm_beta_1d_m);
+            instance_norm_beta_1d_m = std::make_shared<ov::op::v0::Concat>(instance_norm_beta_concat_inputs, 0);
+            auto instance_norm_beta_1d_out = instance_norm_beta_1d_m->get_default_output();
+            auto instance_norm_beta_1d_ps = instance_norm_beta_1d_out.get_partial_shape();
+            if (instance_norm_beta_1d_ps != expected_param_shape)
+                return false;
+        }
+
+        if (pattern_map.count(instance_norm_gamma_m) > 0) {
+            auto instance_norm_gamma = pattern_map.at(instance_norm_gamma_m);
+            auto instance_norm_gamma_ps = instance_norm_gamma.get_partial_shape();
+
+            // instance_norm_gamma has to share the same data type as
+            // pattern input
+            if (instance_norm_gamma.get_element_type() != T)
+                return false;
+
+            // instance_norm_gamma has to be static
+            if (instance_norm_gamma_ps.is_dynamic())
+                return false;
+
+            // number of elements in instance_norm_gamma must be equal to
+            // number of groups
+            if (ov::shape_size(instance_norm_gamma.get_shape()) != num_groups)
+                return false;
+
+            // ensure that instance_norm_gamma will have shape compatible
+            // with group_norm parameters, i.e. 1D vector of shape (num_channels)
+            std::shared_ptr<ov::Node> instance_norm_gamma_1d_m = nullptr;
+            if (ov::shape_size(instance_norm_gamma.get_shape()) == 1) {
+                auto shape_1d_const_m = op::v0::Constant::create(element::i64, Shape{1}, {1});
+                instance_norm_gamma_1d_m =
+                    std::make_shared<ov::op::v1::Reshape>(instance_norm_gamma, shape_1d_const_m, true);
+            } else {
+                instance_norm_gamma_1d_m = std::make_shared<ov::op::v0::Squeeze>(instance_norm_gamma);
+            }
+            ov::OutputVector instance_norm_gamma_concat_inputs;
+            for (auto i = 0; i < channels_to_groups_ratio; i++)
+                instance_norm_gamma_concat_inputs.push_back(instance_norm_gamma_1d_m);
+            instance_norm_gamma_1d_m = std::make_shared<ov::op::v0::Concat>(instance_norm_gamma_concat_inputs, 0);
+            auto instance_norm_gamma_1d_out = instance_norm_gamma_1d_m->get_default_output();
+            auto instance_norm_gamma_1d_ps = instance_norm_gamma_1d_out.get_partial_shape();
+            if (instance_norm_gamma_1d_ps != expected_param_shape)
+                return false;
+
+            // group_norm_gamma /= instance_norm_gamma
+            group_norm_gamma_1d_m =
+                std::make_shared<ov::op::v1::Divide>(group_norm_gamma_1d_m, instance_norm_gamma_1d_m);
+
+            if (pattern_map.count(instance_norm_beta_m) > 0) {
+                // group_norm_beta -= group_norm_gamma * instance_norm_beta / instance_norm_gamma
+                auto group_norm_beta_corr_multiply_m =
+                    std::make_shared<ov::op::v1::Multiply>(group_norm_gamma_1d_m, instance_norm_beta_1d_m);
+                auto group_norm_beta_corr_divide_m =
+                    std::make_shared<ov::op::v1::Divide>(group_norm_beta_corr_multiply_m, instance_norm_gamma_1d_m);
+                group_norm_beta_1d_m =
+                    std::make_shared<ov::op::v1::Subtract>(group_norm_beta_1d_m, group_norm_beta_corr_divide_m);
+            }
+        } else {
+            if (pattern_map.count(instance_norm_beta_m) > 0) {
+                // group_norm_beta -= group_norm_gamma * instance_norm_beta
+                auto group_norm_beta_corr_multiply_m =
+                    std::make_shared<ov::op::v1::Multiply>(group_norm_gamma_1d_m, instance_norm_beta_1d_m);
+                group_norm_beta_1d_m =
+                    std::make_shared<ov::op::v1::Subtract>(group_norm_beta_1d_m, group_norm_beta_corr_multiply_m);
+            }
+        }
+
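+        // at this point group_norm_gamma_1d_m and group_norm_beta_1d_m hold the
+        // folded affine parameters, so a single GroupNormalization op fed with
+        // the original input reproduces the whole matched subgraph
+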
+        // we need to be able to cast mvn to MVN layer type
+        // in order to read actual epsilon value
+        auto mvn_out = pattern_map.at(mvn_m);
+        auto mvn = std::dynamic_pointer_cast<ov::op::v6::MVN>(mvn_out.get_node_shared_ptr());
+        auto epsilon = mvn->get_eps();
+
+        // we can finally create GroupNormalization op
+        std::shared_ptr<ov::Node> group_norm = std::make_shared<ov::op::v12::GroupNormalization>(input,
+                                                                                                 group_norm_gamma_1d_m,
+                                                                                                 group_norm_beta_1d_m,
+                                                                                                 num_groups,
+                                                                                                 epsilon);
+
+        // and do actual graph substitution
+        group_norm->set_friendly_name(m.get_match_root()->get_friendly_name());
+        ov::copy_runtime_info(m.get_matched_nodes(), group_norm);
+        ov::replace_node(m.get_match_root(), group_norm);
+        return true;
+    };
+
+    auto m = std::make_shared<ov::pass::pattern::Matcher>(group_norm_beta_add_m, matcher_name);
+    this->register_matcher(m, callback);
+}

From eeba957d37393b3d7990e69b4f06d63e45510794 Mon Sep 17 00:00:00 2001
From: Jedrzej Hajduczenia
Date: Sat, 11 Jan 2025 01:27:16 +0000
Subject: [PATCH 02/45] Add GroupNormalization fusion tests

---
 .../group_normalization_fusion_tests.cpp     | 502 ++++++++++++++++++
 1 file changed, 502 insertions(+)
 create mode 100644 src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp

diff --git a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp
new file mode 100644
index 00000000000000..2bc76ad44ff230
--- /dev/null
+++ b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp
@@ -0,0 +1,502 @@
+#include <gtest/gtest.h>
+
+#include "common_test_utils/data_utils.hpp"
+#include "common_test_utils/ov_test_utils.hpp"
+#include "openvino/core/model.hpp"
+#include "openvino/op/add.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/group_normalization.hpp"
+#include "openvino/op/multiply.hpp"
+#include "openvino/op/mvn.hpp"
+#include "openvino/op/reshape.hpp"
+#include "openvino/op/shape_of.hpp"
+#include "openvino/pass/manager.hpp"
+#include "openvino/pass/serialize.hpp"
+#include "transformations/common_optimizations/group_normalization_fusion.hpp"
+#include "transformations/init_node_info.hpp"
+
+using namespace testing;
+using namespace ov;
+
+class GroupNormalizationFusionValueParametrizedTestsFixture
+    : public ::testing::TestWithParam<
+          std::tuple<bool, PartialShape, Shape, Shape, Shape, Shape, size_t, float>> {};
+
+TEST_P(GroupNormalizationFusionValueParametrizedTestsFixture, GroupNormalizationFusionTestValueParametrizedTests) {
+    auto params = GetParam();
+    typedef ov::float16 T_act_t;
+    constexpr auto T_act_elem_t = element::from<T_act_t>();
+    typedef ov::element_type_traits<T_act_elem_t>::value_type T_act_store_t;
+    auto positive_test = std::get<0>(params);
+    auto data_shape = std::get<1>(params);
+    ASSERT_TRUE(data_shape[1].is_static());
+    auto num_channels = static_cast<size_t>(data_shape[1].get_max_length());
+    auto instance_norm_gamma_shape = std::get<2>(params);
+    auto instance_norm_beta_shape = std::get<3>(params);
+    auto group_norm_gamma_shape = std::get<4>(params);
+    auto group_norm_beta_shape = std::get<5>(params);
+    auto num_groups = std::get<6>(params);
+    auto epsilon = std::get<7>(params);
+
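+    // positive_test == true: the pass is expected to fuse the pattern and the
+    // transformed model is compared against a manually built reference with a
+    // single GroupNormalization op; positive_test == false: the pass is
+    // expected to leave the model without any GroupNormalization ops
+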
+    if (positive_test) {
+        if ((instance_norm_gamma_shape != Shape{}) && (shape_size(instance_norm_gamma_shape) != num_groups))
+            FAIL() << "Unexpected shape of instance norm gamma - expected either an empty shape (which means that "
+                      "it will not be put in the graph) or a shape with exactly num_groups elements that can be "
+                      "merged with the result of MVN.";
+
+        if ((instance_norm_beta_shape != Shape{}) && (shape_size(instance_norm_beta_shape) != num_groups))
+            FAIL() << "Unexpected shape of instance norm beta - expected either an empty shape (which means that "
+                      "it will not be put in the graph) or a shape with exactly num_groups elements that can be "
+                      "merged with the result of MVN.";
+
+        if ((group_norm_gamma_shape != Shape{}) && (shape_size(group_norm_gamma_shape) != num_channels))
+            FAIL() << "Unexpected shape of group norm gamma - expected either an empty shape (which means that it "
+                      "will not be put in the graph) or a shape with exactly num_channels elements that can be "
+                      "merged with the result of instance norm.";
+
+        if ((group_norm_beta_shape != Shape{}) && (shape_size(group_norm_beta_shape) != num_channels))
+            FAIL() << "Unexpected shape of group norm beta - expected either an empty shape (which means that it "
+                      "will not be put in the graph) or a shape with exactly num_channels elements that can be "
+                      "merged with the result of instance norm.";
+    }
+    auto instance_norm_gamma_present = (instance_norm_gamma_shape != Shape{});
+    auto instance_norm_beta_present = (instance_norm_beta_shape != Shape{});
+    auto group_norm_beta_present = (group_norm_beta_shape != Shape{});
+    auto group_norm_gamma_present = (group_norm_gamma_shape != Shape{});
+
+    if (positive_test) {
+        instance_norm_gamma_present =
+            instance_norm_gamma_present && (shape_size(instance_norm_gamma_shape) == num_groups);
+        instance_norm_beta_present =
+            instance_norm_beta_present && (shape_size(instance_norm_beta_shape) == num_groups);
+        group_norm_beta_present = group_norm_beta_present && (shape_size(group_norm_beta_shape) == num_channels);
+        group_norm_gamma_present = group_norm_gamma_present && (shape_size(group_norm_gamma_shape) == num_channels);
+    }
+
+    auto instance_norm_gamma_vals = std::vector<T_act_store_t>();
+    if (instance_norm_gamma_present)
+        instance_norm_gamma_vals = test::utils::generateVector<T_act_elem_t>(shape_size(instance_norm_gamma_shape));
+
+    auto instance_norm_beta_vals = std::vector<T_act_store_t>();
+    if (instance_norm_beta_present)
+        instance_norm_beta_vals = test::utils::generateVector<T_act_elem_t>(shape_size(instance_norm_beta_shape));
+
+    auto group_norm_gamma_vals = std::vector<T_act_store_t>();
+    if (group_norm_gamma_present)
+        group_norm_gamma_vals = test::utils::generateVector<T_act_elem_t>(shape_size(group_norm_gamma_shape));
+
+    auto group_norm_beta_vals = std::vector<T_act_store_t>();
+    if (group_norm_beta_present)
+        group_norm_beta_vals = test::utils::generateVector<T_act_elem_t>(shape_size(group_norm_beta_shape));
+
+    std::shared_ptr<Model> model(nullptr), model_ref(nullptr);
+    {
+        auto input = std::make_shared<op::v0::Parameter>(T_act_elem_t, data_shape);
+        auto pre_mvn_shape_const =
+            op::v0::Constant::create(element::i64, Shape{3}, {0, static_cast<int64_t>(num_groups), -1});
+        auto pre_mvn_reshape = std::make_shared<op::v1::Reshape>(input, pre_mvn_shape_const, true);
+
+        auto mvn_axes_const = op::v0::Constant::create(element::i64, Shape{1}, {1});
+        auto mvn =
+            std::make_shared<op::v6::MVN>(pre_mvn_reshape, mvn_axes_const, true, epsilon, op::MVNEpsMode::INSIDE_SQRT);
+
+        std::shared_ptr<Node> opt_instance_norm_gamma_multiply = mvn;
+        if (instance_norm_gamma_present) {
+            auto instance_norm_gamma_const =
+                op::v0::Constant::create(T_act_elem_t, instance_norm_gamma_shape, instance_norm_gamma_vals);
+            opt_instance_norm_gamma_multiply = std::make_shared<op::v1::Multiply>(mvn, instance_norm_gamma_const);
+        }
+
+        std::shared_ptr<Node> opt_instance_norm_beta_add = opt_instance_norm_gamma_multiply;
+        if (instance_norm_beta_present) {
+            auto instance_norm_beta_const =
+                op::v0::Constant::create(T_act_elem_t, instance_norm_beta_shape, instance_norm_beta_vals);
+            opt_instance_norm_beta_add =
+                std::make_shared<op::v1::Add>(opt_instance_norm_gamma_multiply, instance_norm_beta_const);
+        }
+
+        auto post_instance_norm_shape = 
std::make_shared(input); + + auto post_instance_norm_reshape = + std::make_shared(opt_instance_norm_beta_add, post_instance_norm_shape, true); + + std::shared_ptr opt_group_norm_gamma_multiply = post_instance_norm_reshape; + if (group_norm_gamma_present) { + auto group_norm_gamma_const = + op::v0::Constant::create(T_act_elem_t, group_norm_gamma_shape, group_norm_gamma_vals); + opt_group_norm_gamma_multiply = + std::make_shared(post_instance_norm_reshape, group_norm_gamma_const); + } + + std::shared_ptr opt_group_norm_beta_add = opt_group_norm_gamma_multiply; + if (group_norm_beta_present) { + auto group_norm_beta_const = + op::v0::Constant::create(T_act_elem_t, group_norm_beta_shape, group_norm_beta_vals); + opt_group_norm_beta_add = + std::make_shared(opt_group_norm_gamma_multiply, group_norm_beta_const); + } + + model = std::make_shared(NodeVector{opt_group_norm_beta_add}, ParameterVector{input}); + + pass::Manager m; + m.register_pass(); + OV_ASSERT_NO_THROW(m.run_passes(model)); + } + + if (positive_test) { + auto input = std::make_shared(T_act_elem_t, data_shape); + + std::shared_ptr group_norm_beta_1d = nullptr; + std::shared_ptr group_norm_gamma_1d = nullptr; + + if (instance_norm_gamma_present) { + if (!group_norm_gamma_present) + group_norm_gamma_vals = std::vector(num_channels, 1); + auto group_norm_gamma_corr_vals = group_norm_gamma_vals; + for (auto i = 0; i < group_norm_gamma_corr_vals.size(); i++) + group_norm_gamma_corr_vals[i] /= instance_norm_gamma_vals[i % num_groups]; + group_norm_gamma_1d = + op::v0::Constant::create(T_act_elem_t, Shape{num_channels}, group_norm_gamma_corr_vals); + if (instance_norm_beta_present) { + if (!group_norm_beta_present) + group_norm_beta_vals = std::vector(num_channels, 0); + auto group_norm_beta_corr_vals = group_norm_beta_vals; + for (auto i = 0; i < group_norm_beta_corr_vals.size(); i++) + group_norm_beta_corr_vals[i] -= + (group_norm_gamma_corr_vals[i] * instance_norm_beta_vals[i % num_groups]) / + instance_norm_gamma_vals[i % num_groups]; + group_norm_beta_1d = + op::v0::Constant::create(T_act_elem_t, Shape{num_channels}, group_norm_beta_corr_vals); + } + } else { + if (instance_norm_beta_present) { + if (!group_norm_beta_present) + group_norm_beta_vals = std::vector(num_channels, 0); + auto group_norm_beta_corr_vals = group_norm_beta_vals; + for (auto i = 0; i < group_norm_beta_corr_vals.size(); i++) + group_norm_beta_corr_vals[i] -= group_norm_gamma_vals[i] * instance_norm_beta_vals[i % num_groups]; + group_norm_beta_1d = + op::v0::Constant::create(T_act_elem_t, Shape{num_channels}, group_norm_beta_corr_vals); + } + } + + if (group_norm_gamma_present) { + if (group_norm_gamma_1d == nullptr) { + group_norm_gamma_1d = + op::v0::Constant::create(T_act_elem_t, Shape{num_channels}, group_norm_gamma_vals); + } + } else { + group_norm_gamma_1d = op::v0::Constant::create(T_act_elem_t, + Shape{num_channels}, + std::vector(num_channels, 1)); + } + + if (group_norm_beta_present) { + if (group_norm_beta_1d == nullptr) { + group_norm_beta_1d = op::v0::Constant::create(T_act_elem_t, Shape{num_channels}, group_norm_beta_vals); + } + } else { + group_norm_beta_1d = op::v0::Constant::create(T_act_elem_t, + Shape{num_channels}, + std::vector(num_channels, 0)); + } + + auto group_norm = std::make_shared(input, + group_norm_gamma_1d, + group_norm_beta_1d, + num_groups, + epsilon); + + model_ref = std::make_shared(NodeVector{group_norm}, ParameterVector{input}); + } + + if (positive_test) { + ASSERT_EQ(count_ops_of_type(model), 1); + auto fc = 
FunctionsComparator::no_default().enable(FunctionsComparator::ACCURACY); + auto res = fc.compare(model, model_ref); + ASSERT_TRUE(res.valid) << res.message; + } else { + ASSERT_EQ(count_ops_of_type(model), 0); + } +} + +INSTANTIATE_TEST_SUITE_P( + GroupNormalizationFusionValueParametrizedPositiveTests, + GroupNormalizationFusionValueParametrizedTestsFixture, + ::testing::Values( + std::make_tuple(true, Shape{1, 320}, Shape{}, Shape{}, Shape{320}, Shape{320}, 1, 1e-5f), + std::make_tuple(true, + Shape{1, 320, 2, 2}, + Shape{1, 1, 1}, + Shape{1, 1, 1}, + Shape{320, 1, 1}, + Shape{1, 320, 1, 1}, + 1, + 1e-5f), + std::make_tuple(true, + Shape{1, 320, 2, 2}, + Shape{1, 320, 1}, + Shape{1, 320, 1}, + Shape{320, 1, 1}, + Shape{320, 1, 1}, + 320, + 1e-5f), + std::make_tuple(true, + PartialShape{Dimension::dynamic(), 320, Dimension::dynamic(), Dimension::dynamic()}, + Shape{1, 320, 1}, + Shape{1, 320, 1}, + Shape{320, 1, 1}, + Shape{320, 1, 1}, + 320, + 1e-5f), + std::make_tuple(true, + PartialShape{Dimension::dynamic(), 320}, + Shape{32, 1}, + Shape{32, 1}, + Shape{320}, + Shape{320}, + 32, + 1e-5f), + std::make_tuple(true, + PartialShape{1, 320, Dimension::dynamic()}, + Shape{32, 1}, + Shape{32, 1}, + Shape{320, 1}, + Shape{320, 1}, + 32, + 1e-5f), + std::make_tuple(true, + PartialShape{1, 320, 2, Dimension::dynamic()}, + Shape{1, 32, 1}, + Shape{1, 32, 1}, + Shape{320, 1, 1}, + Shape{320, 1, 1}, + 32, + 1e-5f), + std::make_tuple(true, Shape{2, 320, 4, 8}, Shape{}, Shape{}, Shape{320, 1, 1}, Shape{1, 320, 1, 1}, 32, 1e-5f), + std::make_tuple(true, + PartialShape{1, 512, Dimension::dynamic(), Dimension::dynamic()}, + Shape{}, + Shape{1, 128, 1}, + Shape{1, 512, 1, 1}, + Shape{512, 1, 1}, + 128, + 1e-6f), + std::make_tuple(true, + Shape{1, 512, 2, 2}, + Shape{1, 64, 1}, + Shape{}, + Shape{1, 512, 1, 1}, + Shape{1, 512, 1, 1}, + 64, + 1e-6f))); + +INSTANTIATE_TEST_SUITE_P( + GroupNormalizationFusionValueParametrizedNegativeTests, + GroupNormalizationFusionValueParametrizedTestsFixture, + ::testing::Values( + std::make_tuple(false, Shape{1, 320}, Shape{}, Shape{}, Shape{}, Shape{}, 1, 1e-5f), + std::make_tuple(false, + Shape{1, 320, 2, 2}, + Shape{1, 1, 1}, + Shape{1, 1, 1}, + Shape{1, 1, 1}, + Shape{1, 1, 1, 1}, + 1, + 1e-5f), + std::make_tuple(false, Shape{1, 320, 2, 2}, Shape{}, Shape{}, Shape{320, 1, 1}, Shape{}, 1, 1e-5f), + std::make_tuple(false, Shape{1, 320, 2, 2}, Shape{}, Shape{}, Shape{}, Shape{1, 320, 1, 1}, 1, 1e-5f), + std::make_tuple(false, + Shape{1, 320, 2, 2}, + Shape{1, 1, 1}, + Shape{1, 32, 1}, + Shape{320, 1, 1}, + Shape{320, 1, 1}, + 32, + 1e-5f), + std::make_tuple(false, + Shape{1, 320, 2, 2}, + Shape{1, 32, 1}, + Shape{1, 1, 1}, + Shape{320, 1, 1}, + Shape{320, 1, 1}, + 32, + 1e-5f), + std::make_tuple(false, + PartialShape{Dimension::dynamic(), 512, Dimension::dynamic(), Dimension::dynamic()}, + Shape{}, + Shape{}, + Shape{1, 512, 1, 1}, + Shape{1, 512, 1, 1}, + 100, + 1e-6f))); + +template +class GroupNormalizationFusionTestMultiType { +public: + constexpr static bool positive_test = positive_test; + using T_act_t = T_act; + using T_gn_gamma_t = T_gn_gamma; + using T_gn_beta_t = T_gn_beta; + using T_in_gamma_t = T_in_gamma; + using T_in_beta_t = T_in_beta; +}; + +template +class GroupNormalizationFusionTypeParametrizedTestsFixture : public ::testing::Test {}; + +using GroupNormalizationFusionPositiveTestTypes = + ::testing::Types, + GroupNormalizationFusionTestMultiType, + GroupNormalizationFusionTestMultiType>; + +using GroupNormalizationFusionNegativeTestTypes = + 
::testing::Types, + GroupNormalizationFusionTestMultiType, + GroupNormalizationFusionTestMultiType, + GroupNormalizationFusionTestMultiType, + GroupNormalizationFusionTestMultiType, + GroupNormalizationFusionTestMultiType, + GroupNormalizationFusionTestMultiType, + GroupNormalizationFusionTestMultiType, + GroupNormalizationFusionTestMultiType, + GroupNormalizationFusionTestMultiType, + GroupNormalizationFusionTestMultiType, + GroupNormalizationFusionTestMultiType>; + +TYPED_TEST_SUITE_P(GroupNormalizationFusionTypeParametrizedTestsFixture); + +TYPED_TEST_P(GroupNormalizationFusionTypeParametrizedTestsFixture, GroupNormalizationFusionTypeParametrizedTests) { + constexpr bool positive_test = TypeParam::positive_test; + + typedef TypeParam::T_act_t T_act_t; + typedef TypeParam::T_gn_gamma_t T_gn_gamma_t; + typedef TypeParam::T_gn_beta_t T_gn_beta_t; + typedef TypeParam::T_in_gamma_t T_in_gamma_t; + typedef TypeParam::T_in_beta_t T_in_beta_t; + + constexpr auto T_act_elem_t = element::from(); + constexpr auto T_gn_gamma_elem_t = element::from(); + constexpr auto T_gn_beta_elem_t = element::from(); + constexpr auto T_in_gamma_elem_t = element::from(); + constexpr auto T_in_beta_elem_t = element::from(); + + typedef ov::element_type_traits::value_type T_act_store_t; + typedef ov::element_type_traits::value_type T_gn_gamma_store_t; + typedef ov::element_type_traits::value_type T_gn_beta_store_t; + typedef ov::element_type_traits::value_type T_in_gamma_store_t; + typedef ov::element_type_traits::value_type T_in_beta_store_t; + + auto data_shape = Shape{1, 320, 2, 2}; + auto instance_norm_gamma_shape = Shape{1, 32, 1}; + auto instance_norm_beta_shape = Shape{1, 32, 1}; + auto group_norm_gamma_shape = Shape{1, 320, 1, 1}; + auto group_norm_beta_shape = Shape{1, 320, 1, 1}; + + auto num_channels = 320ull; + auto num_groups = 32; + auto epsilon = 1e-5f; + + auto instance_norm_gamma_vals = + test::utils::generateVector(shape_size(instance_norm_gamma_shape)); + auto instance_norm_beta_vals = test::utils::generateVector(shape_size(instance_norm_beta_shape)); + auto group_norm_gamma_vals = test::utils::generateVector(shape_size(group_norm_gamma_shape)); + auto group_norm_beta_vals = test::utils::generateVector(shape_size(group_norm_beta_shape)); + + std::shared_ptr model(nullptr), model_ref(nullptr); + { + auto input = std::make_shared(T_act_elem_t, data_shape); + auto pre_mvn_shape_const = + op::v0::Constant::create(element::i64, Shape{3}, {0, static_cast(num_groups), -1}); + auto pre_mvn_reshape = std::make_shared(input, pre_mvn_shape_const, true); + + auto mvn_axes_const = op::v0::Constant::create(element::i64, Shape{1}, {1}); + auto mvn = + std::make_shared(pre_mvn_reshape, mvn_axes_const, true, epsilon, op::MVNEpsMode::INSIDE_SQRT); + + auto instance_norm_gamma_const = + op::v0::Constant::create(T_in_gamma_elem_t, instance_norm_gamma_shape, instance_norm_gamma_vals); + auto instance_norm_gamma_multiply = std::make_shared(mvn, instance_norm_gamma_const); + + auto instance_norm_beta_const = + op::v0::Constant::create(T_in_beta_elem_t, instance_norm_beta_shape, instance_norm_beta_vals); + auto instance_norm_beta_add = + std::make_shared(instance_norm_gamma_multiply, instance_norm_beta_const); + + auto post_instance_norm_shape = std::make_shared(input); + + auto post_instance_norm_reshape = + std::make_shared(instance_norm_beta_add, post_instance_norm_shape, true); + + auto group_norm_gamma_const = + op::v0::Constant::create(T_gn_gamma_elem_t, group_norm_gamma_shape, group_norm_gamma_vals); + auto 
group_norm_gamma_multiply = + std::make_shared(post_instance_norm_reshape, group_norm_gamma_const); + + auto group_norm_beta_const = + op::v0::Constant::create(T_gn_beta_elem_t, group_norm_beta_shape, group_norm_beta_vals); + auto group_norm_beta_add = std::make_shared(group_norm_gamma_multiply, group_norm_beta_const); + + model = std::make_shared(NodeVector{group_norm_beta_add}, ParameterVector{input}); + + pass::Manager m; + m.register_pass(); + OV_ASSERT_NO_THROW(m.run_passes(model)); + } + + if (positive_test) { + auto input = std::make_shared(T_act_elem_t, data_shape); + + auto group_norm_gamma_corr_vals = group_norm_gamma_vals; + + for (auto i = 0; i < group_norm_gamma_corr_vals.size(); i++) + group_norm_gamma_corr_vals[i] /= instance_norm_gamma_vals[i % num_groups]; + + auto group_norm_gamma_1d = + op::v0::Constant::create(T_gn_gamma_elem_t, Shape{num_channels}, group_norm_gamma_corr_vals); + + auto group_norm_beta_corr_vals = group_norm_beta_vals; + for (auto i = 0; i < group_norm_beta_corr_vals.size(); i++) + group_norm_beta_corr_vals[i] -= (group_norm_gamma_corr_vals[i] * instance_norm_beta_vals[i % num_groups]) / + instance_norm_gamma_vals[i % num_groups]; + auto group_norm_beta_1d = + op::v0::Constant::create(T_gn_beta_elem_t, Shape{num_channels}, group_norm_beta_corr_vals); + + auto group_norm = std::make_shared(input, + group_norm_gamma_1d, + group_norm_beta_1d, + num_groups, + epsilon); + + model_ref = std::make_shared(NodeVector{group_norm}, ParameterVector{input}); + } + + if (positive_test) { + ASSERT_EQ(count_ops_of_type(model), 1); + auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::ACCURACY); + auto res = fc.compare(model, model_ref); + ASSERT_TRUE(res.valid) << res.message; + } else { + ASSERT_EQ(count_ops_of_type(model), 0); + } +} + +REGISTER_TYPED_TEST_SUITE_P(GroupNormalizationFusionTypeParametrizedTestsFixture, + GroupNormalizationFusionTypeParametrizedTests); + +INSTANTIATE_TYPED_TEST_SUITE_P(GroupNormalizationFusionTypeParametrizedPositiveTests, + GroupNormalizationFusionTypeParametrizedTestsFixture, + GroupNormalizationFusionPositiveTestTypes); + +INSTANTIATE_TYPED_TEST_SUITE_P(GroupNormalizationFusionTypeParametrizedNegativeTests, + GroupNormalizationFusionTypeParametrizedTestsFixture, + GroupNormalizationFusionNegativeTestTypes); \ No newline at end of file From 67f3af3dde5dc5b6e22d96dcefcdcd017bb90c94 Mon Sep 17 00:00:00 2001 From: Jedrzej Hajduczenia Date: Sat, 11 Jan 2025 01:28:15 +0000 Subject: [PATCH 03/45] Enable GroupNormalization fusion pass in GPU plugin --- .../intel_gpu/src/plugin/transformations_pipeline.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index f6bf70c77c1efc..1a37abd03685ec 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -93,6 +93,7 @@ #include "transformations/common_optimizations/broadcast_transition.hpp" #include "transformations/common_optimizations/common_optimizations.hpp" #include "transformations/common_optimizations/convert_quantize_dequantize.hpp" +#include "transformations/common_optimizations/group_normalization_fusion.hpp" #include "transformations/common_optimizations/lin_op_sequence_fusion.hpp" #include "transformations/common_optimizations/lstm_cell_fusion.hpp" #include "transformations/common_optimizations/move_eltwise_up_data_movement.hpp" @@ -340,6 +341,13 @@ 
void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
         auto pass_config = manager.get_pass_config();
         manager.set_per_pass_validation(false);
 
+        // fuse following ops into GroupNormalization:
+        // group_norm_gamma * (instance_norm_gamma * MVN(x) + instance_norm_beta) + group_norm_beta
+        // note that instance norm related parameters are optional:
+        // - instance_norm_gamma is assumed to be filled with ones if not present in the graph
+        // - instance_norm_beta is assumed to be filled with zeros if not present in the graph
+        manager.register_pass<ov::pass::GroupNormalizationFusion>();
+
         // Temporary solution, global rt info cleanup is needed
         for (auto& node : func->get_ops()) {
             ov::enable_constant_folding(node);

From e15b97fb646e85adca96fc8085e20eeb4bd07586 Mon Sep 17 00:00:00 2001
From: Jedrzej Hajduczenia
Date: Mon, 13 Jan 2025 11:47:09 +0100
Subject: [PATCH 04/45] Update copyright notice

---
 .../common_optimizations/group_normalization_fusion.hpp       | 2 +-
 .../common_optimizations/group_normalization_fusion.cpp       | 2 +-
 .../common_optimizations/group_normalization_fusion_tests.cpp | 4 ++++
 3 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/common/transformations/include/transformations/common_optimizations/group_normalization_fusion.hpp b/src/common/transformations/include/transformations/common_optimizations/group_normalization_fusion.hpp
index d7ad56946295c7..eb1977583b9654 100644
--- a/src/common/transformations/include/transformations/common_optimizations/group_normalization_fusion.hpp
+++ b/src/common/transformations/include/transformations/common_optimizations/group_normalization_fusion.hpp
@@ -1,4 +1,4 @@
-// Copyright (C) 2024 Intel Corporation
+// Copyright (C) 2025 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 //
 
diff --git a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp
index 50a04c459b989e..4d37ca650cd619 100644
--- a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp
+++ b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp
@@ -1,4 +1,4 @@
-// Copyright (C) 2024 Intel Corporation
+// Copyright (C) 2025 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 //
 
diff --git a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp
index 2bc76ad44ff230..5f5489bb5d9184 100644
--- a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp
+++ b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp
@@ -1,3 +1,7 @@
+// Copyright (C) 2025 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
 #include <gtest/gtest.h>

From f027f01f7a0821a38ebe26f94512b17aa29b552e Mon Sep 17 00:00:00 2001
From: Jedrzej Hajduczenia
Date: Tue, 14 Jan 2025 14:04:38 +0100
Subject: [PATCH 05/45] Refactor GroupNormalizationFusion tests to avoid
 changes in core API

---
 .../group_normalization_fusion_tests.cpp     | 950 ++++++++++--------
 1 file changed, 507 insertions(+), 443 deletions(-)

diff --git a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp
index 5f5489bb5d9184..aea795a0826f31 100644
--- 
a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp +++ b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp @@ -22,485 +22,549 @@ using namespace testing; using namespace ov; -class GroupNormalizationFusionValueParametrizedTestsFixture +template +class GroupNormalizationFusionTestsFixture : public ::testing::TestWithParam< - std::tuple> {}; - -TEST_P(GroupNormalizationFusionValueParametrizedTestsFixture, GroupNormalizationFusionTestValueParametrizedTests) { - auto params = GetParam(); - typedef ov::float16 T_act_t; - constexpr auto T_act_elem_t = element::from(); - typedef ov::element_type_traits::value_type T_act_store_t; - auto positive_test = std::get<0>(params); - auto data_shape = std::get<1>(params); - ASSERT_TRUE(data_shape[1].is_static()); - auto num_channels = static_cast(data_shape[1].get_max_length()); - auto instance_norm_gamma_shape = std::get<2>(params); - auto instance_norm_beta_shape = std::get<3>(params); - auto group_norm_gamma_shape = std::get<4>(params); - auto group_norm_beta_shape = std::get<5>(params); - auto num_groups = std::get<6>(params); - auto epsilon = std::get<7>(params); - - if (positive_test) { - if ((instance_norm_gamma_shape != Shape{}) && (shape_size(instance_norm_gamma_shape) != num_groups)) - FAIL() - << "Unexpected shape of instance norm beta - expected either empty shape (which means that it will not " - "be put in the graph) or shape with exactly num_groups elements that can be merged with the result " - "of MVN."; - - if ((instance_norm_beta_shape != Shape{}) && (shape_size(instance_norm_beta_shape) != num_groups)) - FAIL() - << "Unexpected shape of instance norm beta - expected either empty shape (which means that it will not " - "be put in the graph) or shape with exactly num_groups elements that can be merged with the result " - "of MVN."; - - if ((group_norm_gamma_shape != Shape{}) && (shape_size(group_norm_gamma_shape) != num_channels)) - FAIL() - << "Unexpected shape of group norm gamma - expected either empty shape (which means that it will not " - "be put in the graph) or shape with exactly num_channels elements that can be merged with the " - "result " - "of instance norm."; - - if ((group_norm_beta_shape != Shape{}) && (shape_size(group_norm_gamma_shape) != num_channels)) - FAIL() << "Unexpected shape of group norm beta - expected either empty shape (which means that it will not " - "be put in the graph) or shape with exactly num_channels elements that can be merged with the " - "result " - "of instance norm."; - } - auto instance_norm_gamma_present = (instance_norm_gamma_shape != Shape{}); - auto instance_norm_beta_present = (instance_norm_beta_shape != Shape{}); - auto group_norm_beta_present = (group_norm_beta_shape != Shape{}); - auto group_norm_gamma_present = (group_norm_gamma_shape != Shape{}); - - if (positive_test) { - instance_norm_gamma_present = - instance_norm_gamma_present && (shape_size(instance_norm_gamma_shape) == num_groups); - instance_norm_beta_present = instance_norm_beta_present && (shape_size(instance_norm_beta_shape) == num_groups); - group_norm_beta_present = group_norm_beta_present && (shape_size(group_norm_beta_shape) == num_channels); - group_norm_gamma_present = group_norm_gamma_present && (shape_size(group_norm_gamma_shape) == num_channels); - } - - auto instance_norm_gamma_vals = std::vector(); - if (instance_norm_gamma_present) - instance_norm_gamma_vals = test::utils::generateVector(shape_size(instance_norm_gamma_shape)); - - 
auto instance_norm_beta_vals = std::vector(); - if (instance_norm_beta_present) - instance_norm_beta_vals = test::utils::generateVector(shape_size(instance_norm_beta_shape)); - - auto group_norm_gamma_vals = std::vector(); - if (group_norm_gamma_present) - group_norm_gamma_vals = test::utils::generateVector(shape_size(group_norm_gamma_shape)); - - auto group_norm_beta_vals = std::vector(); - if (group_norm_beta_present) - group_norm_beta_vals = test::utils::generateVector(shape_size(group_norm_beta_shape)); - - std::shared_ptr model(nullptr), model_ref(nullptr); - { - auto input = std::make_shared(T_act_elem_t, data_shape); - auto pre_mvn_shape_const = - op::v0::Constant::create(element::i64, Shape{3}, {0, static_cast(num_groups), -1}); - auto pre_mvn_reshape = std::make_shared(input, pre_mvn_shape_const, true); - - auto mvn_axes_const = op::v0::Constant::create(element::i64, Shape{1}, {1}); - auto mvn = - std::make_shared(pre_mvn_reshape, mvn_axes_const, true, epsilon, op::MVNEpsMode::INSIDE_SQRT); - - std::shared_ptr opt_instance_norm_gamma_multiply = mvn; - if (instance_norm_gamma_present) { - auto instance_norm_gamma_const = - op::v0::Constant::create(T_act_elem_t, instance_norm_gamma_shape, instance_norm_gamma_vals); - opt_instance_norm_gamma_multiply = std::make_shared(mvn, instance_norm_gamma_const); + std::tuple> { +public: + static constexpr element::Type_t T_act_elem_t = T_act_elem; + static constexpr element::Type_t T_gn_gamma_elem_t = T_gn_gamma_elem; + static constexpr element::Type_t T_gn_beta_elem_t = T_gn_beta_elem; + static constexpr element::Type_t T_in_gamma_elem_t = T_in_gamma_elem; + static constexpr element::Type_t T_in_beta_elem_t = T_in_beta_elem; + + typedef typename ov::element_type_traits::value_type T_act_store_t; + typedef typename ov::element_type_traits::value_type T_gn_gamma_store_t; + typedef typename ov::element_type_traits::value_type T_gn_beta_store_t; + typedef typename ov::element_type_traits::value_type T_in_gamma_store_t; + typedef typename ov::element_type_traits::value_type T_in_beta_store_t; + + virtual void TestBody() { + auto params = GetParam(); + auto positive_test = std::get<0>(params); + auto data_shape = std::get<1>(params); + ASSERT_TRUE(data_shape[1].is_static()); + auto num_channels = static_cast(data_shape[1].get_max_length()); + auto instance_norm_gamma_shape = std::get<2>(params); + auto instance_norm_beta_shape = std::get<3>(params); + auto group_norm_gamma_shape = std::get<4>(params); + auto group_norm_beta_shape = std::get<5>(params); + auto num_groups = std::get<6>(params); + auto epsilon = std::get<7>(params); + + if (positive_test) { + if ((instance_norm_gamma_shape != Shape{}) && (shape_size(instance_norm_gamma_shape) != num_groups)) + FAIL() << "Unexpected shape of instance norm beta - expected either empty shape (which means that it " + "will not " + "be put in the graph) or shape with exactly num_groups elements that can be merged with the " + "result " + "of MVN."; + + if ((instance_norm_beta_shape != Shape{}) && (shape_size(instance_norm_beta_shape) != num_groups)) + FAIL() << "Unexpected shape of instance norm beta - expected either empty shape (which means that it " + "will not " + "be put in the graph) or shape with exactly num_groups elements that can be merged with the " + "result " + "of MVN."; + + if ((group_norm_gamma_shape != Shape{}) && (shape_size(group_norm_gamma_shape) != num_channels)) + FAIL() + << "Unexpected shape of group norm gamma - expected either empty shape (which means that it will " + "not " + "be 
put in the graph) or shape with exactly num_channels elements that can be merged with the " + "result " + "of instance norm."; + + if ((group_norm_beta_shape != Shape{}) && (shape_size(group_norm_gamma_shape) != num_channels)) + FAIL() + << "Unexpected shape of group norm beta - expected either empty shape (which means that it will " + "not " + "be put in the graph) or shape with exactly num_channels elements that can be merged with the " + "result " + "of instance norm."; } - - std::shared_ptr opt_instance_norm_beta_add = opt_instance_norm_gamma_multiply; - if (instance_norm_beta_present) { - auto instance_norm_beta_const = - op::v0::Constant::create(T_act_elem_t, instance_norm_beta_shape, instance_norm_beta_vals); - opt_instance_norm_beta_add = - std::make_shared(opt_instance_norm_gamma_multiply, instance_norm_beta_const); + auto instance_norm_gamma_present = (instance_norm_gamma_shape != Shape{}); + auto instance_norm_beta_present = (instance_norm_beta_shape != Shape{}); + auto group_norm_beta_present = (group_norm_beta_shape != Shape{}); + auto group_norm_gamma_present = (group_norm_gamma_shape != Shape{}); + + if (positive_test) { + instance_norm_gamma_present = + instance_norm_gamma_present && (shape_size(instance_norm_gamma_shape) == num_groups); + instance_norm_beta_present = + instance_norm_beta_present && (shape_size(instance_norm_beta_shape) == num_groups); + group_norm_beta_present = group_norm_beta_present && (shape_size(group_norm_beta_shape) == num_channels); + group_norm_gamma_present = group_norm_gamma_present && (shape_size(group_norm_gamma_shape) == num_channels); } - auto post_instance_norm_shape = std::make_shared(input); + auto instance_norm_gamma_vals = std::vector(); + if (instance_norm_gamma_present) + instance_norm_gamma_vals = + test::utils::generateVector(shape_size(instance_norm_gamma_shape)); + + auto instance_norm_beta_vals = std::vector(); + if (instance_norm_beta_present) + instance_norm_beta_vals = + test::utils::generateVector(shape_size(instance_norm_beta_shape)); + + auto group_norm_gamma_vals = std::vector(); + if (group_norm_gamma_present) + group_norm_gamma_vals = test::utils::generateVector(shape_size(group_norm_gamma_shape)); + + auto group_norm_beta_vals = std::vector(); + if (group_norm_beta_present) + group_norm_beta_vals = test::utils::generateVector(shape_size(group_norm_beta_shape)); + + std::shared_ptr model(nullptr), model_ref(nullptr); + { + auto input = std::make_shared(T_act_elem_t, data_shape); + auto pre_mvn_shape_const = op::v0::Constant::create(element::i64, + Shape{3}, + {0, static_cast(num_groups), -1}); + auto pre_mvn_reshape = std::make_shared(input, pre_mvn_shape_const, true); + + auto mvn_axes_const = op::v0::Constant::create(element::i64, Shape{1}, {1}); + auto mvn = std::make_shared(pre_mvn_reshape, + mvn_axes_const, + true, + epsilon, + op::MVNEpsMode::INSIDE_SQRT); + + std::shared_ptr opt_instance_norm_gamma_multiply = mvn; + if (instance_norm_gamma_present) { + auto instance_norm_gamma_const = + op::v0::Constant::create(T_in_gamma_elem_t, instance_norm_gamma_shape, instance_norm_gamma_vals); + opt_instance_norm_gamma_multiply = std::make_shared(mvn, instance_norm_gamma_const); + } - auto post_instance_norm_reshape = - std::make_shared(opt_instance_norm_beta_add, post_instance_norm_shape, true); + std::shared_ptr opt_instance_norm_beta_add = opt_instance_norm_gamma_multiply; + if (instance_norm_beta_present) { + auto instance_norm_beta_const = + op::v0::Constant::create(T_in_beta_elem_t, instance_norm_beta_shape, 
instance_norm_beta_vals); + opt_instance_norm_beta_add = + std::make_shared(opt_instance_norm_gamma_multiply, instance_norm_beta_const); + } - std::shared_ptr opt_group_norm_gamma_multiply = post_instance_norm_reshape; - if (group_norm_gamma_present) { - auto group_norm_gamma_const = - op::v0::Constant::create(T_act_elem_t, group_norm_gamma_shape, group_norm_gamma_vals); - opt_group_norm_gamma_multiply = - std::make_shared(post_instance_norm_reshape, group_norm_gamma_const); - } + auto post_instance_norm_shape = std::make_shared(input); - std::shared_ptr opt_group_norm_beta_add = opt_group_norm_gamma_multiply; - if (group_norm_beta_present) { - auto group_norm_beta_const = - op::v0::Constant::create(T_act_elem_t, group_norm_beta_shape, group_norm_beta_vals); - opt_group_norm_beta_add = - std::make_shared(opt_group_norm_gamma_multiply, group_norm_beta_const); - } + auto post_instance_norm_reshape = + std::make_shared(opt_instance_norm_beta_add, post_instance_norm_shape, true); - model = std::make_shared(NodeVector{opt_group_norm_beta_add}, ParameterVector{input}); + std::shared_ptr opt_group_norm_gamma_multiply = post_instance_norm_reshape; + if (group_norm_gamma_present) { + auto group_norm_gamma_const = + op::v0::Constant::create(T_gn_gamma_elem_t, group_norm_gamma_shape, group_norm_gamma_vals); + opt_group_norm_gamma_multiply = + std::make_shared(post_instance_norm_reshape, group_norm_gamma_const); + } - pass::Manager m; - m.register_pass(); - OV_ASSERT_NO_THROW(m.run_passes(model)); - } + std::shared_ptr opt_group_norm_beta_add = opt_group_norm_gamma_multiply; + if (group_norm_beta_present) { + auto group_norm_beta_const = + op::v0::Constant::create(T_gn_beta_elem_t, group_norm_beta_shape, group_norm_beta_vals); + opt_group_norm_beta_add = + std::make_shared(opt_group_norm_gamma_multiply, group_norm_beta_const); + } - if (positive_test) { - auto input = std::make_shared(T_act_elem_t, data_shape); + model = std::make_shared(NodeVector{opt_group_norm_beta_add}, ParameterVector{input}); - std::shared_ptr group_norm_beta_1d = nullptr; - std::shared_ptr group_norm_gamma_1d = nullptr; + pass::Manager m; + m.register_pass(); + OV_ASSERT_NO_THROW(m.run_passes(model)); + } - if (instance_norm_gamma_present) { - if (!group_norm_gamma_present) - group_norm_gamma_vals = std::vector(num_channels, 1); - auto group_norm_gamma_corr_vals = group_norm_gamma_vals; - for (auto i = 0; i < group_norm_gamma_corr_vals.size(); i++) - group_norm_gamma_corr_vals[i] /= instance_norm_gamma_vals[i % num_groups]; - group_norm_gamma_1d = - op::v0::Constant::create(T_act_elem_t, Shape{num_channels}, group_norm_gamma_corr_vals); - if (instance_norm_beta_present) { - if (!group_norm_beta_present) - group_norm_beta_vals = std::vector(num_channels, 0); - auto group_norm_beta_corr_vals = group_norm_beta_vals; - for (auto i = 0; i < group_norm_beta_corr_vals.size(); i++) - group_norm_beta_corr_vals[i] -= - (group_norm_gamma_corr_vals[i] * instance_norm_beta_vals[i % num_groups]) / - instance_norm_gamma_vals[i % num_groups]; - group_norm_beta_1d = - op::v0::Constant::create(T_act_elem_t, Shape{num_channels}, group_norm_beta_corr_vals); + if (positive_test) { + auto input = std::make_shared(T_act_elem_t, data_shape); + + std::shared_ptr group_norm_beta_1d = nullptr; + std::shared_ptr group_norm_gamma_1d = nullptr; + + if (instance_norm_gamma_present) { + if (!group_norm_gamma_present) + group_norm_gamma_vals = std::vector(num_channels, 1); + auto group_norm_gamma_corr_vals = group_norm_gamma_vals; + for (auto i = 0; i < 
group_norm_gamma_corr_vals.size(); i++) + group_norm_gamma_corr_vals[i] /= instance_norm_gamma_vals[i % num_groups]; + group_norm_gamma_1d = + op::v0::Constant::create(T_gn_gamma_elem_t, Shape{num_channels}, group_norm_gamma_corr_vals); + if (instance_norm_beta_present) { + if (!group_norm_beta_present) + group_norm_beta_vals = std::vector(num_channels, 0); + auto group_norm_beta_corr_vals = group_norm_beta_vals; + for (auto i = 0; i < group_norm_beta_corr_vals.size(); i++) + group_norm_beta_corr_vals[i] -= + (group_norm_gamma_corr_vals[i] * instance_norm_beta_vals[i % num_groups]) / + instance_norm_gamma_vals[i % num_groups]; + group_norm_beta_1d = + op::v0::Constant::create(T_gn_beta_elem_t, Shape{num_channels}, group_norm_beta_corr_vals); + } + } else { + if (instance_norm_beta_present) { + if (!group_norm_beta_present) + group_norm_beta_vals = std::vector(num_channels, 0); + auto group_norm_beta_corr_vals = group_norm_beta_vals; + for (auto i = 0; i < group_norm_beta_corr_vals.size(); i++) + group_norm_beta_corr_vals[i] -= + group_norm_gamma_vals[i] * instance_norm_beta_vals[i % num_groups]; + group_norm_beta_1d = + op::v0::Constant::create(T_gn_beta_elem_t, Shape{num_channels}, group_norm_beta_corr_vals); + } } - } else { - if (instance_norm_beta_present) { - if (!group_norm_beta_present) - group_norm_beta_vals = std::vector(num_channels, 0); - auto group_norm_beta_corr_vals = group_norm_beta_vals; - for (auto i = 0; i < group_norm_beta_corr_vals.size(); i++) - group_norm_beta_corr_vals[i] -= group_norm_gamma_vals[i] * instance_norm_beta_vals[i % num_groups]; - group_norm_beta_1d = - op::v0::Constant::create(T_act_elem_t, Shape{num_channels}, group_norm_beta_corr_vals); + + if (group_norm_gamma_present) { + if (group_norm_gamma_1d == nullptr) { + group_norm_gamma_1d = + op::v0::Constant::create(T_gn_gamma_elem_t, Shape{num_channels}, group_norm_gamma_vals); + } + } else { + group_norm_gamma_1d = op::v0::Constant::create(T_gn_gamma_elem_t, + Shape{num_channels}, + std::vector(num_channels, 1)); } - } - if (group_norm_gamma_present) { - if (group_norm_gamma_1d == nullptr) { - group_norm_gamma_1d = - op::v0::Constant::create(T_act_elem_t, Shape{num_channels}, group_norm_gamma_vals); + if (group_norm_beta_present) { + if (group_norm_beta_1d == nullptr) { + group_norm_beta_1d = + op::v0::Constant::create(T_gn_beta_elem_t, Shape{num_channels}, group_norm_beta_vals); + } + } else { + group_norm_beta_1d = op::v0::Constant::create(T_gn_beta_elem_t, + Shape{num_channels}, + std::vector(num_channels, 0)); } - } else { - group_norm_gamma_1d = op::v0::Constant::create(T_act_elem_t, - Shape{num_channels}, - std::vector(num_channels, 1)); + + auto group_norm = std::make_shared(input, + group_norm_gamma_1d, + group_norm_beta_1d, + num_groups, + epsilon); + + model_ref = std::make_shared(NodeVector{group_norm}, ParameterVector{input}); } - if (group_norm_beta_present) { - if (group_norm_beta_1d == nullptr) { - group_norm_beta_1d = op::v0::Constant::create(T_act_elem_t, Shape{num_channels}, group_norm_beta_vals); - } + if (positive_test) { + ASSERT_EQ(count_ops_of_type(model), 1); + auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::ACCURACY); + auto res = fc.compare(model, model_ref); + ASSERT_TRUE(res.valid) << res.message; } else { - group_norm_beta_1d = op::v0::Constant::create(T_act_elem_t, - Shape{num_channels}, - std::vector(num_channels, 0)); + ASSERT_EQ(count_ops_of_type(model), 0); } + } +}; - auto group_norm = std::make_shared(input, - group_norm_gamma_1d, - 
group_norm_beta_1d, - num_groups, - epsilon); +class GroupNormalizationFusionTestsFixture_f16 : public GroupNormalizationFusionTestsFixture {}; +class GroupNormalizationFusionTestsFixture_bf16 : public GroupNormalizationFusionTestsFixture {}; +class GroupNormalizationFusionTestsFixture_f32 : public GroupNormalizationFusionTestsFixture {}; +class GroupNormalizationFusionTestsFixture_u8 : public GroupNormalizationFusionTestsFixture {}; +class GroupNormalizationFusionTestsFixture_u16 : public GroupNormalizationFusionTestsFixture {}; +class GroupNormalizationFusionTestsFixture_u32 : public GroupNormalizationFusionTestsFixture {}; +class GroupNormalizationFusionTestsFixture_u64 : public GroupNormalizationFusionTestsFixture {}; +class GroupNormalizationFusionTestsFixture_i8 : public GroupNormalizationFusionTestsFixture {}; +class GroupNormalizationFusionTestsFixture_i16 : public GroupNormalizationFusionTestsFixture {}; +class GroupNormalizationFusionTestsFixture_i32 : public GroupNormalizationFusionTestsFixture {}; +class GroupNormalizationFusionTestsFixture_i64 : public GroupNormalizationFusionTestsFixture {}; +class GroupNormalizationFusionTestsFixture_f8e4m3 + : public GroupNormalizationFusionTestsFixture {}; +class GroupNormalizationFusionTestsFixture_f8e5m2 + : public GroupNormalizationFusionTestsFixture {}; +class GroupNormalizationFusionTestsFixture_f4e2m1 + : public GroupNormalizationFusionTestsFixture {}; +class GroupNormalizationFusionTestsFixture_f8e8m0 + : public GroupNormalizationFusionTestsFixture {}; + +TEST_P(GroupNormalizationFusionTestsFixture_f16, GroupNormalizationFusionTests_f16) { + GroupNormalizationFusionTestsFixture_f16::TestBody(); +} - model_ref = std::make_shared(NodeVector{group_norm}, ParameterVector{input}); - } +TEST_P(GroupNormalizationFusionTestsFixture_bf16, GroupNormalizationFusionTests_bf16) { + GroupNormalizationFusionTestsFixture_bf16::TestBody(); +} - if (positive_test) { - ASSERT_EQ(count_ops_of_type(model), 1); - auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::ACCURACY); - auto res = fc.compare(model, model_ref); - ASSERT_TRUE(res.valid) << res.message; - } else { - ASSERT_EQ(count_ops_of_type(model), 0); - } +TEST_P(GroupNormalizationFusionTestsFixture_f32, GroupNormalizationFusionTests_f32) { + GroupNormalizationFusionTestsFixture_f32::TestBody(); } -INSTANTIATE_TEST_SUITE_P( - GroupNormalizationFusionValueParametrizedPositiveTests, - GroupNormalizationFusionValueParametrizedTestsFixture, - ::testing::Values( - std::make_tuple(true, Shape{1, 320}, Shape{}, Shape{}, Shape{320}, Shape{320}, 1, 1e-5f), - std::make_tuple(true, - Shape{1, 320, 2, 2}, - Shape{1, 1, 1}, - Shape{1, 1, 1}, - Shape{320, 1, 1}, - Shape{1, 320, 1, 1}, - 1, - 1e-5f), - std::make_tuple(true, - Shape{1, 320, 2, 2}, - Shape{1, 320, 1}, - Shape{1, 320, 1}, - Shape{320, 1, 1}, - Shape{320, 1, 1}, - 320, - 1e-5f), - std::make_tuple(true, - PartialShape{Dimension::dynamic(), 320, Dimension::dynamic(), Dimension::dynamic()}, - Shape{1, 320, 1}, - Shape{1, 320, 1}, - Shape{320, 1, 1}, - Shape{320, 1, 1}, - 320, - 1e-5f), - std::make_tuple(true, - PartialShape{Dimension::dynamic(), 320}, - Shape{32, 1}, - Shape{32, 1}, - Shape{320}, - Shape{320}, - 32, - 1e-5f), - std::make_tuple(true, - PartialShape{1, 320, Dimension::dynamic()}, - Shape{32, 1}, - Shape{32, 1}, - Shape{320, 1}, - Shape{320, 1}, - 32, - 1e-5f), - std::make_tuple(true, - PartialShape{1, 320, 2, Dimension::dynamic()}, - Shape{1, 32, 1}, - Shape{1, 32, 1}, - Shape{320, 1, 1}, - Shape{320, 1, 1}, - 32, - 
1e-5f), - std::make_tuple(true, Shape{2, 320, 4, 8}, Shape{}, Shape{}, Shape{320, 1, 1}, Shape{1, 320, 1, 1}, 32, 1e-5f), - std::make_tuple(true, - PartialShape{1, 512, Dimension::dynamic(), Dimension::dynamic()}, - Shape{}, - Shape{1, 128, 1}, - Shape{1, 512, 1, 1}, - Shape{512, 1, 1}, - 128, - 1e-6f), - std::make_tuple(true, - Shape{1, 512, 2, 2}, - Shape{1, 64, 1}, - Shape{}, - Shape{1, 512, 1, 1}, - Shape{1, 512, 1, 1}, - 64, - 1e-6f))); - -INSTANTIATE_TEST_SUITE_P( - GroupNormalizationFusionValueParametrizedNegativeTests, - GroupNormalizationFusionValueParametrizedTestsFixture, - ::testing::Values( - std::make_tuple(false, Shape{1, 320}, Shape{}, Shape{}, Shape{}, Shape{}, 1, 1e-5f), - std::make_tuple(false, - Shape{1, 320, 2, 2}, - Shape{1, 1, 1}, - Shape{1, 1, 1}, - Shape{1, 1, 1}, - Shape{1, 1, 1, 1}, - 1, - 1e-5f), - std::make_tuple(false, Shape{1, 320, 2, 2}, Shape{}, Shape{}, Shape{320, 1, 1}, Shape{}, 1, 1e-5f), - std::make_tuple(false, Shape{1, 320, 2, 2}, Shape{}, Shape{}, Shape{}, Shape{1, 320, 1, 1}, 1, 1e-5f), - std::make_tuple(false, - Shape{1, 320, 2, 2}, - Shape{1, 1, 1}, - Shape{1, 32, 1}, - Shape{320, 1, 1}, - Shape{320, 1, 1}, - 32, - 1e-5f), - std::make_tuple(false, - Shape{1, 320, 2, 2}, - Shape{1, 32, 1}, - Shape{1, 1, 1}, - Shape{320, 1, 1}, - Shape{320, 1, 1}, - 32, - 1e-5f), - std::make_tuple(false, - PartialShape{Dimension::dynamic(), 512, Dimension::dynamic(), Dimension::dynamic()}, - Shape{}, - Shape{}, - Shape{1, 512, 1, 1}, - Shape{1, 512, 1, 1}, - 100, - 1e-6f))); - -template -class GroupNormalizationFusionTestMultiType { -public: - constexpr static bool positive_test = positive_test; - using T_act_t = T_act; - using T_gn_gamma_t = T_gn_gamma; - using T_gn_beta_t = T_gn_beta; - using T_in_gamma_t = T_in_gamma; - using T_in_beta_t = T_in_beta; -}; +TEST_P(GroupNormalizationFusionTestsFixture_u8, GroupNormalizationFusionTests_u8) { + GroupNormalizationFusionTestsFixture_u8::TestBody(); +} -template -class GroupNormalizationFusionTypeParametrizedTestsFixture : public ::testing::Test {}; - -using GroupNormalizationFusionPositiveTestTypes = - ::testing::Types, - GroupNormalizationFusionTestMultiType, - GroupNormalizationFusionTestMultiType>; - -using GroupNormalizationFusionNegativeTestTypes = - ::testing::Types, - GroupNormalizationFusionTestMultiType, - GroupNormalizationFusionTestMultiType, - GroupNormalizationFusionTestMultiType, - GroupNormalizationFusionTestMultiType, - GroupNormalizationFusionTestMultiType, - GroupNormalizationFusionTestMultiType, - GroupNormalizationFusionTestMultiType, - GroupNormalizationFusionTestMultiType, - GroupNormalizationFusionTestMultiType, - GroupNormalizationFusionTestMultiType, - GroupNormalizationFusionTestMultiType>; - -TYPED_TEST_SUITE_P(GroupNormalizationFusionTypeParametrizedTestsFixture); - -TYPED_TEST_P(GroupNormalizationFusionTypeParametrizedTestsFixture, GroupNormalizationFusionTypeParametrizedTests) { - constexpr bool positive_test = TypeParam::positive_test; - - typedef TypeParam::T_act_t T_act_t; - typedef TypeParam::T_gn_gamma_t T_gn_gamma_t; - typedef TypeParam::T_gn_beta_t T_gn_beta_t; - typedef TypeParam::T_in_gamma_t T_in_gamma_t; - typedef TypeParam::T_in_beta_t T_in_beta_t; - - constexpr auto T_act_elem_t = element::from(); - constexpr auto T_gn_gamma_elem_t = element::from(); - constexpr auto T_gn_beta_elem_t = element::from(); - constexpr auto T_in_gamma_elem_t = element::from(); - constexpr auto T_in_beta_elem_t = element::from(); - - typedef ov::element_type_traits::value_type T_act_store_t; - 
typedef ov::element_type_traits::value_type T_gn_gamma_store_t; - typedef ov::element_type_traits::value_type T_gn_beta_store_t; - typedef ov::element_type_traits::value_type T_in_gamma_store_t; - typedef ov::element_type_traits::value_type T_in_beta_store_t; - - auto data_shape = Shape{1, 320, 2, 2}; - auto instance_norm_gamma_shape = Shape{1, 32, 1}; - auto instance_norm_beta_shape = Shape{1, 32, 1}; - auto group_norm_gamma_shape = Shape{1, 320, 1, 1}; - auto group_norm_beta_shape = Shape{1, 320, 1, 1}; - - auto num_channels = 320ull; - auto num_groups = 32; - auto epsilon = 1e-5f; - - auto instance_norm_gamma_vals = - test::utils::generateVector(shape_size(instance_norm_gamma_shape)); - auto instance_norm_beta_vals = test::utils::generateVector(shape_size(instance_norm_beta_shape)); - auto group_norm_gamma_vals = test::utils::generateVector(shape_size(group_norm_gamma_shape)); - auto group_norm_beta_vals = test::utils::generateVector(shape_size(group_norm_beta_shape)); - - std::shared_ptr model(nullptr), model_ref(nullptr); - { - auto input = std::make_shared(T_act_elem_t, data_shape); - auto pre_mvn_shape_const = - op::v0::Constant::create(element::i64, Shape{3}, {0, static_cast(num_groups), -1}); - auto pre_mvn_reshape = std::make_shared(input, pre_mvn_shape_const, true); - - auto mvn_axes_const = op::v0::Constant::create(element::i64, Shape{1}, {1}); - auto mvn = - std::make_shared(pre_mvn_reshape, mvn_axes_const, true, epsilon, op::MVNEpsMode::INSIDE_SQRT); - - auto instance_norm_gamma_const = - op::v0::Constant::create(T_in_gamma_elem_t, instance_norm_gamma_shape, instance_norm_gamma_vals); - auto instance_norm_gamma_multiply = std::make_shared(mvn, instance_norm_gamma_const); - - auto instance_norm_beta_const = - op::v0::Constant::create(T_in_beta_elem_t, instance_norm_beta_shape, instance_norm_beta_vals); - auto instance_norm_beta_add = - std::make_shared(instance_norm_gamma_multiply, instance_norm_beta_const); - - auto post_instance_norm_shape = std::make_shared(input); - - auto post_instance_norm_reshape = - std::make_shared(instance_norm_beta_add, post_instance_norm_shape, true); - - auto group_norm_gamma_const = - op::v0::Constant::create(T_gn_gamma_elem_t, group_norm_gamma_shape, group_norm_gamma_vals); - auto group_norm_gamma_multiply = - std::make_shared(post_instance_norm_reshape, group_norm_gamma_const); - - auto group_norm_beta_const = - op::v0::Constant::create(T_gn_beta_elem_t, group_norm_beta_shape, group_norm_beta_vals); - auto group_norm_beta_add = std::make_shared(group_norm_gamma_multiply, group_norm_beta_const); - - model = std::make_shared(NodeVector{group_norm_beta_add}, ParameterVector{input}); - - pass::Manager m; - m.register_pass(); - OV_ASSERT_NO_THROW(m.run_passes(model)); - } +TEST_P(GroupNormalizationFusionTestsFixture_u16, GroupNormalizationFusionTests_u16) { + GroupNormalizationFusionTestsFixture_u16::TestBody(); +} - if (positive_test) { - auto input = std::make_shared(T_act_elem_t, data_shape); +TEST_P(GroupNormalizationFusionTestsFixture_u32, GroupNormalizationFusionTests_u32) { + GroupNormalizationFusionTestsFixture_u32::TestBody(); +} - auto group_norm_gamma_corr_vals = group_norm_gamma_vals; +TEST_P(GroupNormalizationFusionTestsFixture_u64, GroupNormalizationFusionTests_u64) { + GroupNormalizationFusionTestsFixture_u64::TestBody(); +} - for (auto i = 0; i < group_norm_gamma_corr_vals.size(); i++) - group_norm_gamma_corr_vals[i] /= instance_norm_gamma_vals[i % num_groups]; +TEST_P(GroupNormalizationFusionTestsFixture_i8, 
GroupNormalizationFusionTests_i8) { + GroupNormalizationFusionTestsFixture_i8::TestBody(); +} - auto group_norm_gamma_1d = - op::v0::Constant::create(T_gn_gamma_elem_t, Shape{num_channels}, group_norm_gamma_corr_vals); +TEST_P(GroupNormalizationFusionTestsFixture_i16, GroupNormalizationFusionTests_i16) { + GroupNormalizationFusionTestsFixture_i16::TestBody(); +} - auto group_norm_beta_corr_vals = group_norm_beta_vals; - for (auto i = 0; i < group_norm_beta_corr_vals.size(); i++) - group_norm_beta_corr_vals[i] -= (group_norm_gamma_corr_vals[i] * instance_norm_beta_vals[i % num_groups]) / - instance_norm_gamma_vals[i % num_groups]; - auto group_norm_beta_1d = - op::v0::Constant::create(T_gn_beta_elem_t, Shape{num_channels}, group_norm_beta_corr_vals); +TEST_P(GroupNormalizationFusionTestsFixture_i32, GroupNormalizationFusionTests_i32) { + GroupNormalizationFusionTestsFixture_i32::TestBody(); +} - auto group_norm = std::make_shared(input, - group_norm_gamma_1d, - group_norm_beta_1d, - num_groups, - epsilon); +TEST_P(GroupNormalizationFusionTestsFixture_i64, GroupNormalizationFusionTests_i64) { + GroupNormalizationFusionTestsFixture_i64::TestBody(); +} - model_ref = std::make_shared(NodeVector{group_norm}, ParameterVector{input}); - } +TEST_P(GroupNormalizationFusionTestsFixture_f8e4m3, GroupNormalizationFusionTests_f8e4m3) { + GroupNormalizationFusionTestsFixture_f8e4m3::TestBody(); +} + +TEST_P(GroupNormalizationFusionTestsFixture_f8e5m2, GroupNormalizationFusionTests_f8e5m2) { + GroupNormalizationFusionTestsFixture_f8e5m2::TestBody(); +} + +TEST_P(GroupNormalizationFusionTestsFixture_f4e2m1, GroupNormalizationFusionTests_f4e2m1) { + GroupNormalizationFusionTestsFixture_f4e2m1::TestBody(); +} - if (positive_test) { - ASSERT_EQ(count_ops_of_type(model), 1); - auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::ACCURACY); - auto res = fc.compare(model, model_ref); - ASSERT_TRUE(res.valid) << res.message; - } else { - ASSERT_EQ(count_ops_of_type(model), 0); +TEST_P(GroupNormalizationFusionTestsFixture_f8e8m0, GroupNormalizationFusionTests_f8e8m0) { + GroupNormalizationFusionTestsFixture_f8e8m0::TestBody(); +} + +using RawValuesContainer = std::tuple; +using ValuesContainerWithPositiveTestFlag = + std::tuple; + +std::vector valid_vals = { + std::make_tuple(Shape{1, 320}, Shape{}, Shape{}, Shape{320}, Shape{320}, 1, 1e-5f), + std::make_tuple(Shape{1, 320, 2, 2}, + Shape{1, 1, 1}, + Shape{1, 1, 1}, + Shape{320, 1, 1}, + Shape{1, 320, 1, 1}, + 1, + 1e-5f), + std::make_tuple(Shape{1, 320, 2, 2}, + Shape{1, 320, 1}, + Shape{1, 320, 1}, + Shape{320, 1, 1}, + Shape{320, 1, 1}, + 320, + 1e-5f), + std::make_tuple(PartialShape{Dimension::dynamic(), 320, Dimension::dynamic(), Dimension::dynamic()}, + Shape{1, 320, 1}, + Shape{1, 320, 1}, + Shape{320, 1, 1}, + Shape{320, 1, 1}, + 320, + 1e-5f), + std::make_tuple(PartialShape{Dimension::dynamic(), 320}, + Shape{32, 1}, + Shape{32, 1}, + Shape{320}, + Shape{320}, + 32, + 1e-5f), + std::make_tuple(PartialShape{1, 320, Dimension::dynamic()}, + Shape{32, 1}, + Shape{32, 1}, + Shape{320, 1}, + Shape{320, 1}, + 32, + 1e-5f), + std::make_tuple(PartialShape{1, 320, 2, Dimension::dynamic()}, + Shape{1, 32, 1}, + Shape{1, 32, 1}, + Shape{320, 1, 1}, + Shape{320, 1, 1}, + 32, + 1e-5f), + std::make_tuple(Shape{2, 320, 4, 8}, Shape{}, Shape{}, Shape{320, 1, 1}, Shape{1, 320, 1, 1}, 32, 1e-5f), + std::make_tuple(PartialShape{1, 512, Dimension::dynamic(), Dimension::dynamic()}, + Shape{}, + Shape{1, 128, 1}, + Shape{1, 512, 1, 1}, + Shape{512, 1, 1}, + 
128, + 1e-6f), + std::make_tuple(Shape{1, 512, 2, 2}, + Shape{1, 64, 1}, + Shape{}, + Shape{1, 512, 1, 1}, + Shape{1, 512, 1, 1}, + 64, + 1e-6f)}; + +auto invalid_vals = ::testing::Values( + std::make_tuple(false, Shape{1, 320}, Shape{}, Shape{}, Shape{}, Shape{}, 1, 1e-5f), + std::make_tuple(false, + Shape{1, 320, 2, 2}, + Shape{1, 1, 1}, + Shape{1, 1, 1}, + Shape{1, 1, 1}, + Shape{1, 1, 1, 1}, + 1, + 1e-5f), + std::make_tuple(false, Shape{1, 320, 2, 2}, Shape{}, Shape{}, Shape{320, 1, 1}, Shape{}, 1, 1e-5f), + std::make_tuple(false, Shape{1, 320, 2, 2}, Shape{}, Shape{}, Shape{}, Shape{1, 320, 1, 1}, 1, 1e-5f), + std::make_tuple(false, + Shape{1, 320, 2, 2}, + Shape{1, 1, 1}, + Shape{1, 32, 1}, + Shape{320, 1, 1}, + Shape{320, 1, 1}, + 32, + 1e-5f), + std::make_tuple(false, + Shape{1, 320, 2, 2}, + Shape{1, 32, 1}, + Shape{1, 1, 1}, + Shape{320, 1, 1}, + Shape{320, 1, 1}, + 32, + 1e-5f), + std::make_tuple(false, + PartialShape{Dimension::dynamic(), 512, Dimension::dynamic(), Dimension::dynamic()}, + Shape{}, + Shape{}, + Shape{1, 512, 1, 1}, + Shape{1, 512, 1, 1}, + 100, + 1e-6f)); + +std::vector add_positive_test_flag_to_vals( + const bool positive_test, + const std::vector& vals) { + std::vector res; + for (const RawValuesContainer& t : vals) { + auto new_val = std::tuple_cat(std::tuple(positive_test), t); + res.push_back(new_val); } + return res; } -REGISTER_TYPED_TEST_SUITE_P(GroupNormalizationFusionTypeParametrizedTestsFixture, - GroupNormalizationFusionTypeParametrizedTests); +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionPositiveTests_f16, + GroupNormalizationFusionTestsFixture_f16, + ::testing::ValuesIn(add_positive_test_flag_to_vals(true, valid_vals))); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_f16, + GroupNormalizationFusionTestsFixture_f16, + invalid_vals); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionPositiveTests_bf16, + GroupNormalizationFusionTestsFixture_bf16, + ::testing::ValuesIn(add_positive_test_flag_to_vals(true, valid_vals))); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_bf16, + GroupNormalizationFusionTestsFixture_bf16, + invalid_vals); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionPositiveTests_f32, + GroupNormalizationFusionTestsFixture_f32, + ::testing::ValuesIn(add_positive_test_flag_to_vals(true, valid_vals))); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTests_f32, + GroupNormalizationFusionTestsFixture_f32, + invalid_vals); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_u8, + GroupNormalizationFusionTestsFixture_u8, + ::testing::ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_u8, + GroupNormalizationFusionTestsFixture_u8, + invalid_vals); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_u16, + GroupNormalizationFusionTestsFixture_u16, + ::testing::ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_u16, + GroupNormalizationFusionTestsFixture_u16, + invalid_vals); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_u32, + GroupNormalizationFusionTestsFixture_u32, + ::testing::ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_u32, + GroupNormalizationFusionTestsFixture_u32, + invalid_vals); + 
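// Note on the remaining instantiations: only the f16/bf16/f32 fixtures get
// positive suites; every integer and f8/f4 fixture reuses valid_vals with the
// positive-test flag forced to false ("NegativeTestsValidVals"), asserting
// that a structurally fusable graph is still left untouched for unsupported
// activation types. For the integer fixtures this is enforced by the type
// gate at the top of the pass callback (quoted from
// group_normalization_fusion.cpp):
//
//     // this pattern supports only real and not quantized data types
//     if ((!T.is_real()) || (T.is_quantized()))
//         return false;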
+INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_u64, + GroupNormalizationFusionTestsFixture_u64, + ::testing::ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_u64, + GroupNormalizationFusionTestsFixture_u64, + invalid_vals); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_i8, + GroupNormalizationFusionTestsFixture_i8, + ::testing::ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_i8, + GroupNormalizationFusionTestsFixture_i8, + invalid_vals); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_i16, + GroupNormalizationFusionTestsFixture_i16, + ::testing::ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_i16, + GroupNormalizationFusionTestsFixture_i16, + invalid_vals); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_i32, + GroupNormalizationFusionTestsFixture_i32, + ::testing::ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_i32, + GroupNormalizationFusionTestsFixture_i32, + invalid_vals); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_i64, + GroupNormalizationFusionTestsFixture_i64, + ::testing::ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_i64, + GroupNormalizationFusionTestsFixture_i64, + invalid_vals); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_f8e4m3, + GroupNormalizationFusionTestsFixture_f8e4m3, + ::testing::ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_f8e4m3, + GroupNormalizationFusionTestsFixture_f8e4m3, + invalid_vals); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_f8e5m2, + GroupNormalizationFusionTestsFixture_f8e5m2, + ::testing::ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_f8e5m2, + GroupNormalizationFusionTestsFixture_f8e5m2, + invalid_vals); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_f4e2m1, + GroupNormalizationFusionTestsFixture_f4e2m1, + ::testing::ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_f4e2m1, + GroupNormalizationFusionTestsFixture_f4e2m1, + invalid_vals); -INSTANTIATE_TYPED_TEST_SUITE_P(GroupNormalizationFusionTypeParametrizedPositiveTests, - GroupNormalizationFusionTypeParametrizedTestsFixture, - GroupNormalizationFusionPositiveTestTypes); +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_f8e8m0, + GroupNormalizationFusionTestsFixture_f8e8m0, + ::testing::ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); -INSTANTIATE_TYPED_TEST_SUITE_P(GroupNormalizationFusionTypeParametrizedNegativeTests, - GroupNormalizationFusionTypeParametrizedTestsFixture, - GroupNormalizationFusionNegativeTestTypes); \ No newline at end of file +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_f8e8m0, + GroupNormalizationFusionTestsFixture_f8e8m0, + invalid_vals); \ No newline at end of file From 
f85df0a5f5324dd34c444bc0747462211611657e Mon Sep 17 00:00:00 2001 From: Jedrzej Hajduczenia Date: Thu, 16 Jan 2025 12:52:48 +0100 Subject: [PATCH 06/45] Remove GPU plugin specific GroupNormComposition pass --- .../group_norm_composition.cpp | 110 ------------------ .../group_norm_composition.hpp | 17 --- .../src/plugin/transformations_pipeline.cpp | 3 - 3 files changed, 130 deletions(-) delete mode 100644 src/plugins/intel_gpu/src/plugin/transformations/group_norm_composition.cpp delete mode 100644 src/plugins/intel_gpu/src/plugin/transformations/group_norm_composition.hpp diff --git a/src/plugins/intel_gpu/src/plugin/transformations/group_norm_composition.cpp b/src/plugins/intel_gpu/src/plugin/transformations/group_norm_composition.cpp deleted file mode 100644 index 1db88916163dd9..00000000000000 --- a/src/plugins/intel_gpu/src/plugin/transformations/group_norm_composition.cpp +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include "group_norm_composition.hpp" - -#include "openvino/core/rt_info.hpp" -#include "openvino/op/constant.hpp" -#include "openvino/op/multiply.hpp" -#include "openvino/op/mvn.hpp" -#include "openvino/op/shape_of.hpp" -#include "openvino/op/add.hpp" -#include "openvino/op/convert.hpp" -#include "openvino/op/reshape.hpp" -#include "openvino/op/group_normalization.hpp" -#include "openvino/op/squeeze.hpp" -#include "openvino/pass/pattern/op/or.hpp" -#include "openvino/pass/pattern/op/wrap_type.hpp" -#include "transformations/utils/utils.hpp" - -namespace ov::intel_gpu { - -GroupNormComposition::GroupNormComposition() { - using namespace ov::pass::pattern; - using ov::pass::pattern::op::Or; - - // Detect Group-Normalization decomposition pattern - // y = scale * MVN(x) + bias - auto data_m = any_input(); - auto pre_reshape_const_m = wrap_type(); - auto pre_reshape_m = wrap_type({data_m, pre_reshape_const_m}); - auto axes_const_m = wrap_type(); - auto mvn_m = wrap_type({pre_reshape_m, axes_const_m}); - auto shapeof_m = wrap_type({data_m}); - auto post_reshape_m = wrap_type({mvn_m, shapeof_m}); - auto scale_const_m = wrap_type(); - auto convert_scale_const_m = wrap_type({scale_const_m}); - auto scale_m = std::make_shared(OutputVector{scale_const_m, convert_scale_const_m}); - auto mul_m = wrap_type({post_reshape_m, scale_m}); - auto bias_const_m = wrap_type(); - auto convert_bias_const_m = wrap_type({bias_const_m}); - auto bias_m = std::make_shared(OutputVector{bias_const_m, convert_bias_const_m}); - auto add_m = wrap_type({mul_m, bias_m}); - - ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) { - const auto& pattern_map = m.get_pattern_value_map(); - - auto data = pattern_map.at(data_m); - auto data_pshape = data.get_partial_shape(); - // Feature dim should be static. - if (data_pshape[1].is_dynamic()) { - return false; - } - auto feature_dim = data_pshape[1].get_max_length(); - - auto scale = pattern_map.at(scale_const_m); - { - // The total number of elements in scale must be equal to feature_dim. 
- auto const_scale = ov::as_type_ptr(scale.get_node_shared_ptr()); - auto const_scale_shape = const_scale->get_output_shape(0); - int64_t const_scale_size = 1; - for (auto& dim : const_scale_shape) { - const_scale_size *= dim; - } - if (const_scale_size != feature_dim) { - return false; - } - } - if (pattern_map.count(convert_scale_const_m) != 0) { - scale = pattern_map.at(convert_scale_const_m); - } - auto scale_1d = std::make_shared(scale); - auto bias = pattern_map.at(bias_const_m); - { - // The total number of elements in bias must be equal to feature_dim. - auto const_bias = ov::as_type_ptr(bias.get_node_shared_ptr()); - auto const_bias_shape = const_bias->get_output_shape(0); - int64_t const_bias_size = 1; - for (auto& dim : const_bias_shape) { - const_bias_size *= dim; - } - if (const_bias_size != feature_dim) { - return false; - } - } - if (pattern_map.count(convert_bias_const_m) != 0) { - bias = pattern_map.at(convert_bias_const_m); - } - auto bias_1d = std::make_shared(bias); - - auto pre_reshape = ov::as_type_ptr(pattern_map.at(pre_reshape_m).get_node_shared_ptr()); - auto pre_reshape_pshape = pre_reshape->get_output_partial_shape(0); - auto num_groups = pre_reshape_pshape[1].get_max_length(); - - auto mvn = ov::as_type_ptr(pattern_map.at(mvn_m).get_node_shared_ptr()); - - auto group_norm = std::make_shared(data, scale_1d, bias_1d, num_groups, mvn->get_eps()); - - group_norm->set_friendly_name(m.get_match_root()->get_friendly_name()); - ov::copy_runtime_info(m.get_matched_nodes(), group_norm); - ov::replace_node(m.get_match_root(), group_norm); - - return true; - }; - - auto m = std::make_shared(add_m, "GroupNormComposition"); - this->register_matcher(m, callback); -} - -} // namespace ov::intel_gpu diff --git a/src/plugins/intel_gpu/src/plugin/transformations/group_norm_composition.hpp b/src/plugins/intel_gpu/src/plugin/transformations/group_norm_composition.hpp deleted file mode 100644 index 6f9fca08696791..00000000000000 --- a/src/plugins/intel_gpu/src/plugin/transformations/group_norm_composition.hpp +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include "openvino/pass/graph_rewrite.hpp" - -namespace ov::intel_gpu { - -class GroupNormComposition : public ov::pass::MatcherPass { -public: - OPENVINO_MATCHER_PASS_RTTI("GroupNormComposition"); - GroupNormComposition(); -}; - -} // namespace ov::intel_gpu diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index 1a37abd03685ec..ab70e737035c8d 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -82,7 +82,6 @@ #include "plugin/transformations/unsqueeze_broadcast_reshape_matmul_fusion.hpp" #include "plugin/transformations/unsqueeze_broadcast_reshape_sdpa_fusion.hpp" #include "plugin/transformations/increase_position_ids_precision.hpp" -#include "plugin/transformations/group_norm_composition.hpp" #include "plugin/transformations/dynamic_quantize_fully_connected.hpp" #include "plugin/transformations/optimize_subsequent_reshapes.hpp" #include "plugin/transformations/lora_horizontal_fusion.hpp" @@ -432,8 +431,6 @@ void TransformationsPipeline::apply(std::shared_ptr func) { return !is_type(next_node); }); - manager.register_pass(); - // Disable subtract folding only for the dGPUs to meet the requirements of oneDNN: // it expects to have the same data type for weights and 
zero points (apply it only for u8 data type, since other compression
    // types are not supported by oneDNN)

From acb8d20c17e3a9f39c2a5e52974ca8bf7d903bcd Mon Sep 17 00:00:00 2001
From: Jedrzej Hajduczenia
Date: Thu, 16 Jan 2025 12:55:45 +0100
Subject: [PATCH 07/45] Fix RTTI macro in GroupNormalizationFusion header file

---
 .../common_optimizations/group_normalization_fusion.hpp  | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/common/transformations/include/transformations/common_optimizations/group_normalization_fusion.hpp b/src/common/transformations/include/transformations/common_optimizations/group_normalization_fusion.hpp
index eb1977583b9654..859da4d9a271f9 100644
--- a/src/common/transformations/include/transformations/common_optimizations/group_normalization_fusion.hpp
+++ b/src/common/transformations/include/transformations/common_optimizations/group_normalization_fusion.hpp
@@ -27,6 +27,6 @@ class TRANSFORMATIONS_API GroupNormalizationFusion;

 class ov::pass::GroupNormalizationFusion : public ov::pass::MatcherPass {
 public:
-    OPENVINO_RTTI("GroupNormalizationFusion", "0");
+    OPENVINO_MATCHER_PASS_RTTI("GroupNormalizationFusion");
     GroupNormalizationFusion();
 };

From 3e001773d7d25e6dc3c9fd84bf37d991d0d4ef64 Mon Sep 17 00:00:00 2001
From: Jedrzej Hajduczenia
Date: Thu, 16 Jan 2025 13:00:16 +0100
Subject: [PATCH 08/45] Override TestBody() method in
 GroupNormalizationFusionTestsFixture

---
 .../common_optimizations/group_normalization_fusion_tests.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp
index aea795a0826f31..d80eb6d61017a9 100644
--- a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp
+++ b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp
@@ -43,7 +43,7 @@ class GroupNormalizationFusionTestsFixture
     typedef typename ov::element_type_traits<T_in_gamma_elem_t>::value_type T_in_gamma_store_t;
     typedef typename ov::element_type_traits<T_in_beta_elem_t>::value_type T_in_beta_store_t;

-    virtual void TestBody() {
+    void TestBody() override {
         auto params = GetParam();
         auto positive_test = std::get<0>(params);
         auto data_shape = std::get<1>(params);

From e895c8c4c4544726150dba5aec1cf66b53871662 Mon Sep 17 00:00:00 2001
From: Jedrzej Hajduczenia
Date: Thu, 16 Jan 2025 13:05:41 +0100
Subject: [PATCH 09/45] Explain meaning of GroupNormalizationFusion test
 parameters

---
 .../group_normalization_fusion_tests.cpp | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp
index d80eb6d61017a9..70cb1ee81412a0 100644
--- a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp
+++ b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp
@@ -22,14 +22,23 @@
 using namespace testing;
 using namespace ov;

+using ValuesContainerWithPositiveTestFlag =
+    std::tuple<bool,          // positive_test
+               PartialShape,  // data_shape
+               Shape,         // instance_norm_gamma_shape
+               Shape,         // instance_norm_beta_shape
+               Shape,         // group_norm_gamma_shape
+               Shape,         // group_norm_beta_shape
+               size_t,        // num_groups
+               float>;        // epsilon
+
 template <element::Type_t T_act_elem, element::Type_t T_gn_gamma_elem, element::Type_t T_gn_beta_elem,
           element::Type_t T_in_gamma_elem, element::Type_t T_in_beta_elem>
 class GroupNormalizationFusionTestsFixture
-    : public ::testing::TestWithParam<
-          std::tuple<bool, PartialShape, Shape, Shape, Shape, Shape, size_t, float>> {
+    : public ::testing::TestWithParam<ValuesContainerWithPositiveTestFlag> {
 public:
     static constexpr element::Type_t T_act_elem_t = T_act_elem;
     static constexpr element::Type_t T_gn_gamma_elem_t = T_gn_gamma_elem;
@@ -338,8 +347,6 @@ TEST_P(GroupNormalizationFusionTestsFixture_f8e8m0, GroupNormalizationFusionTest
 }

 using RawValuesContainer = std::tuple<PartialShape, Shape, Shape, Shape, Shape, size_t, float>;
-using ValuesContainerWithPositiveTestFlag =
-    std::tuple<bool, PartialShape, Shape, Shape, Shape, Shape, size_t, float>;

 std::vector<RawValuesContainer> valid_vals = {
     std::make_tuple(Shape{1, 320}, Shape{}, Shape{}, Shape{320}, Shape{320}, 1, 1e-5f),

From c059be4a1f646867cac798043281c644ed909c6b Mon Sep 17 00:00:00 2001
From: Jedrzej Hajduczenia
Date: Thu, 16 Jan 2025 13:08:35 +0100
Subject: [PATCH 10/45] Require providing correct group norm gamma & beta
 shapes in positive GroupNormalizationFusion tests

---
 .../group_normalization_fusion_tests.cpp | 30 +++++++------------
 1 file changed, 10 insertions(+), 20 deletions(-)

diff --git a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp
index 70cb1ee81412a0..60fb5f2adddc23 100644
--- a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp
+++ b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp
@@ -68,33 +68,23 @@ class GroupNormalizationFusionTestsFixture
         if (positive_test) {
             if ((instance_norm_gamma_shape != Shape{}) && (shape_size(instance_norm_gamma_shape) != num_groups))
                 FAIL() << "Unexpected shape of instance norm gamma - expected either empty shape (which means that it "
-                          "will not "
-                          "be put in the graph) or shape with exactly num_groups elements that can be merged with the "
-                          "result "
-                          "of MVN.";
+                          "will not be put in the graph) or shape with exactly num_groups elements that can be "
+                          "merged with the result of MVN.";

             if ((instance_norm_beta_shape != Shape{}) && (shape_size(instance_norm_beta_shape) != num_groups))
                 FAIL() << "Unexpected shape of instance norm beta - expected either empty shape (which means that it "
-                          "will not "
-                          "be put in the graph) or shape with exactly num_groups elements that can be merged with the "
-                          "result "
-                          "of MVN.";
+                          "will not be put in the graph) or shape with exactly num_groups elements that can be "
+                          "merged with the result of MVN.";

-            if ((group_norm_gamma_shape != Shape{}) && (shape_size(group_norm_gamma_shape) != num_channels))
+            if (shape_size(group_norm_gamma_shape) != num_channels)
                 FAIL()
-                    << "Unexpected shape of group norm gamma - expected either empty shape (which means that it will "
-                       "not "
-                       "be put in the graph) or shape with exactly num_channels elements that can be merged with the "
-                       "result "
-                       "of instance norm.";
+                    << "Unexpected shape of group norm gamma - expected shape with exactly num_channels elements that "
+                       "can be merged with the result of instance norm.";

-            if ((group_norm_beta_shape != Shape{}) && (shape_size(group_norm_gamma_shape) != num_channels))
+            if (shape_size(group_norm_beta_shape) != num_channels)
                 FAIL()
-                    << "Unexpected shape of group norm beta - expected either empty shape (which means that it will "
-                       "not "
-                       "be put in the graph) or shape with exactly num_channels elements that can be merged with the "
-                       "result "
-                       "of instance norm.";
+                    << "Unexpected shape of group norm beta - expected shape with exactly num_channels elements that "
+                       "can be merged with the result of instance norm.";
         }

         auto instance_norm_gamma_present = (instance_norm_gamma_shape != Shape{});
         auto instance_norm_beta_present = (instance_norm_beta_shape != Shape{});

From ad5536286e36a827c78d51df40409ae25c9763f0 Mon Sep 17 00:00:00 2001
From: Jedrzej Hajduczenia
Date: Thu, 16 Jan 2025 13:11:39 +0100
Subject: [PATCH 11/45] 
Use dedicated Constant ctor to create scalar constants in GroupNormalizationFusion tests --- .../group_normalization_fusion_tests.cpp | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp index 60fb5f2adddc23..3261747a411bb3 100644 --- a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp +++ b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp @@ -220,9 +220,7 @@ class GroupNormalizationFusionTestsFixture op::v0::Constant::create(T_gn_gamma_elem_t, Shape{num_channels}, group_norm_gamma_vals); } } else { - group_norm_gamma_1d = op::v0::Constant::create(T_gn_gamma_elem_t, - Shape{num_channels}, - std::vector(num_channels, 1)); + group_norm_gamma_1d = std::make_shared(T_gn_gamma_elem_t, Shape{num_channels}, 1); } if (group_norm_beta_present) { @@ -231,9 +229,7 @@ class GroupNormalizationFusionTestsFixture op::v0::Constant::create(T_gn_beta_elem_t, Shape{num_channels}, group_norm_beta_vals); } } else { - group_norm_beta_1d = op::v0::Constant::create(T_gn_beta_elem_t, - Shape{num_channels}, - std::vector(num_channels, 0)); + group_norm_beta_1d = std::make_shared(T_gn_beta_elem_t, Shape{num_channels}, 0); } auto group_norm = std::make_shared(input, From 10ed10e8e1822513f5e06a93575377727420b0e1 Mon Sep 17 00:00:00 2001 From: Jedrzej Hajduczenia Date: Thu, 16 Jan 2025 13:16:42 +0100 Subject: [PATCH 12/45] Avoid Shape->PartialShape conversion for in/out tensors in GroupNormalizationFusion tests --- .../group_normalization_fusion_tests.cpp | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp index 3261747a411bb3..bd7f7e1c371232 100644 --- a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp +++ b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp @@ -335,7 +335,7 @@ TEST_P(GroupNormalizationFusionTestsFixture_f8e8m0, GroupNormalizationFusionTest using RawValuesContainer = std::tuple; std::vector valid_vals = { - std::make_tuple(Shape{1, 320}, Shape{}, Shape{}, Shape{320}, Shape{320}, 1, 1e-5f), + std::make_tuple(PartialShape{1, 320}, Shape{}, Shape{}, Shape{320}, Shape{320}, 1, 1e-5f), std::make_tuple(Shape{1, 320, 2, 2}, Shape{1, 1, 1}, Shape{1, 1, 1}, @@ -343,7 +343,7 @@ std::vector valid_vals = { Shape{1, 320, 1, 1}, 1, 1e-5f), - std::make_tuple(Shape{1, 320, 2, 2}, + std::make_tuple(PartialShape{1, 320, 2, 2}, Shape{1, 320, 1}, Shape{1, 320, 1}, Shape{320, 1, 1}, @@ -378,7 +378,7 @@ std::vector valid_vals = { Shape{320, 1, 1}, 32, 1e-5f), - std::make_tuple(Shape{2, 320, 4, 8}, Shape{}, Shape{}, Shape{320, 1, 1}, Shape{1, 320, 1, 1}, 32, 1e-5f), + std::make_tuple(PartialShape{2, 320, 4, 8}, Shape{}, Shape{}, Shape{320, 1, 1}, Shape{1, 320, 1, 1}, 32, 1e-5f), std::make_tuple(PartialShape{1, 512, Dimension::dynamic(), Dimension::dynamic()}, Shape{}, Shape{1, 128, 1}, @@ -386,7 +386,7 @@ std::vector valid_vals = { Shape{512, 1, 1}, 128, 1e-6f), - std::make_tuple(Shape{1, 512, 2, 2}, + std::make_tuple(PartialShape{1, 512, 2, 2}, Shape{1, 64, 1}, Shape{}, Shape{1, 512, 1, 1}, @@ -395,19 +395,19 @@ std::vector valid_vals = { 
1e-6f)}; auto invalid_vals = ::testing::Values( - std::make_tuple(false, Shape{1, 320}, Shape{}, Shape{}, Shape{}, Shape{}, 1, 1e-5f), + std::make_tuple(false, PartialShape{1, 320}, Shape{}, Shape{}, Shape{}, Shape{}, 1, 1e-5f), std::make_tuple(false, - Shape{1, 320, 2, 2}, + PartialShape{1, 320, 2, 2}, Shape{1, 1, 1}, Shape{1, 1, 1}, Shape{1, 1, 1}, Shape{1, 1, 1, 1}, 1, 1e-5f), - std::make_tuple(false, Shape{1, 320, 2, 2}, Shape{}, Shape{}, Shape{320, 1, 1}, Shape{}, 1, 1e-5f), - std::make_tuple(false, Shape{1, 320, 2, 2}, Shape{}, Shape{}, Shape{}, Shape{1, 320, 1, 1}, 1, 1e-5f), + std::make_tuple(false, PartialShape{1, 320, 2, 2}, Shape{}, Shape{}, Shape{320, 1, 1}, Shape{}, 1, 1e-5f), + std::make_tuple(false, PartialShape{1, 320, 2, 2}, Shape{}, Shape{}, Shape{}, Shape{1, 320, 1, 1}, 1, 1e-5f), std::make_tuple(false, - Shape{1, 320, 2, 2}, + PartialShape{1, 320, 2, 2}, Shape{1, 1, 1}, Shape{1, 32, 1}, Shape{320, 1, 1}, @@ -415,7 +415,7 @@ auto invalid_vals = ::testing::Values( 32, 1e-5f), std::make_tuple(false, - Shape{1, 320, 2, 2}, + PartialShape{1, 320, 2, 2}, Shape{1, 32, 1}, Shape{1, 1, 1}, Shape{320, 1, 1}, From 876470ef2e549dcb2837813e06c35b89b24561c7 Mon Sep 17 00:00:00 2001 From: Jedrzej Hajduczenia Date: Thu, 16 Jan 2025 13:20:38 +0100 Subject: [PATCH 13/45] Use global testing namespace in GroupNormalizationFusion tests --- .../group_normalization_fusion_tests.cpp | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp index bd7f7e1c371232..8ca8507e69d1b3 100644 --- a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp +++ b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp @@ -38,7 +38,7 @@ template class GroupNormalizationFusionTestsFixture - : public ::testing::TestWithParam { + : public TestWithParam { public: static constexpr element::Type_t T_act_elem_t = T_act_elem; static constexpr element::Type_t T_gn_gamma_elem_t = T_gn_gamma_elem; @@ -394,7 +394,7 @@ std::vector valid_vals = { 64, 1e-6f)}; -auto invalid_vals = ::testing::Values( +auto invalid_vals = Values( std::make_tuple(false, PartialShape{1, 320}, Shape{}, Shape{}, Shape{}, Shape{}, 1, 1e-5f), std::make_tuple(false, PartialShape{1, 320, 2, 2}, @@ -444,7 +444,7 @@ std::vector add_positive_test_flag_to_vals( INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionPositiveTests_f16, GroupNormalizationFusionTestsFixture_f16, - ::testing::ValuesIn(add_positive_test_flag_to_vals(true, valid_vals))); + ValuesIn(add_positive_test_flag_to_vals(true, valid_vals))); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_f16, GroupNormalizationFusionTestsFixture_f16, @@ -452,7 +452,7 @@ INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_f16, INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionPositiveTests_bf16, GroupNormalizationFusionTestsFixture_bf16, - ::testing::ValuesIn(add_positive_test_flag_to_vals(true, valid_vals))); + ValuesIn(add_positive_test_flag_to_vals(true, valid_vals))); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_bf16, GroupNormalizationFusionTestsFixture_bf16, @@ -460,7 +460,7 @@ INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_bf16, INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionPositiveTests_f32, 
GroupNormalizationFusionTestsFixture_f32, - ::testing::ValuesIn(add_positive_test_flag_to_vals(true, valid_vals))); + ValuesIn(add_positive_test_flag_to_vals(true, valid_vals))); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTests_f32, GroupNormalizationFusionTestsFixture_f32, @@ -468,7 +468,7 @@ INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTests_f32, INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_u8, GroupNormalizationFusionTestsFixture_u8, - ::testing::ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); + ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_u8, GroupNormalizationFusionTestsFixture_u8, @@ -476,7 +476,7 @@ INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_u8, INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_u16, GroupNormalizationFusionTestsFixture_u16, - ::testing::ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); + ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_u16, GroupNormalizationFusionTestsFixture_u16, @@ -484,7 +484,7 @@ INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_u16, INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_u32, GroupNormalizationFusionTestsFixture_u32, - ::testing::ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); + ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_u32, GroupNormalizationFusionTestsFixture_u32, @@ -492,7 +492,7 @@ INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_u32, INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_u64, GroupNormalizationFusionTestsFixture_u64, - ::testing::ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); + ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_u64, GroupNormalizationFusionTestsFixture_u64, @@ -500,7 +500,7 @@ INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_u64, INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_i8, GroupNormalizationFusionTestsFixture_i8, - ::testing::ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); + ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_i8, GroupNormalizationFusionTestsFixture_i8, @@ -508,7 +508,7 @@ INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_i8, INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_i16, GroupNormalizationFusionTestsFixture_i16, - ::testing::ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); + ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_i16, GroupNormalizationFusionTestsFixture_i16, @@ -516,7 +516,7 @@ INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_i16, INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_i32, GroupNormalizationFusionTestsFixture_i32, - ::testing::ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); + ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); 
INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_i32, GroupNormalizationFusionTestsFixture_i32, @@ -524,7 +524,7 @@ INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_i32, INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_i64, GroupNormalizationFusionTestsFixture_i64, - ::testing::ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); + ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_i64, GroupNormalizationFusionTestsFixture_i64, @@ -532,7 +532,7 @@ INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_i64, INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_f8e4m3, GroupNormalizationFusionTestsFixture_f8e4m3, - ::testing::ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); + ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_f8e4m3, GroupNormalizationFusionTestsFixture_f8e4m3, @@ -540,7 +540,7 @@ INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_f8e4m3 INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_f8e5m2, GroupNormalizationFusionTestsFixture_f8e5m2, - ::testing::ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); + ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_f8e5m2, GroupNormalizationFusionTestsFixture_f8e5m2, @@ -548,7 +548,7 @@ INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_f8e5m2 INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_f4e2m1, GroupNormalizationFusionTestsFixture_f4e2m1, - ::testing::ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); + ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_f4e2m1, GroupNormalizationFusionTestsFixture_f4e2m1, @@ -556,7 +556,7 @@ INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_f4e2m1 INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_f8e8m0, GroupNormalizationFusionTestsFixture_f8e8m0, - ::testing::ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); + ValuesIn(add_positive_test_flag_to_vals(false, valid_vals))); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_f8e8m0, GroupNormalizationFusionTestsFixture_f8e8m0, From 85f0afd66619b25b6762ffed9355d1b5b6ba13fd Mon Sep 17 00:00:00 2001 From: Jedrzej Hajduczenia Date: Thu, 16 Jan 2025 13:22:21 +0100 Subject: [PATCH 14/45] Another update of copyright notice --- .../common_optimizations/group_normalization_fusion.hpp | 2 +- .../common_optimizations/group_normalization_fusion.cpp | 2 +- .../common_optimizations/group_normalization_fusion_tests.cpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/common/transformations/include/transformations/common_optimizations/group_normalization_fusion.hpp b/src/common/transformations/include/transformations/common_optimizations/group_normalization_fusion.hpp index 859da4d9a271f9..95c30ed2135298 100644 --- a/src/common/transformations/include/transformations/common_optimizations/group_normalization_fusion.hpp +++ b/src/common/transformations/include/transformations/common_optimizations/group_normalization_fusion.hpp @@ -1,4 +1,4 @@ -// Copyright (C) 2025 Intel Corporation +// 
Copyright (C) 2018-2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // diff --git a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp index 4d37ca650cd619..3fad152492df25 100644 --- a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp @@ -1,4 +1,4 @@ -// Copyright (C) 2025 Intel Corporation +// Copyright (C) 2018-2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // diff --git a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp index 8ca8507e69d1b3..cdf0563b2395ad 100644 --- a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp +++ b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp @@ -1,4 +1,4 @@ -// Copyright (C) 2025 Intel Corporation +// Copyright (C) 2018-2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // From d8f2becb494cd79ae94cef70f5e27849df0c2cd6 Mon Sep 17 00:00:00 2001 From: Jedrzej Hajduczenia Date: Thu, 16 Jan 2025 13:31:31 +0100 Subject: [PATCH 15/45] Use const references where possible in GroupNormalizationFusion pass --- .../group_normalization_fusion.cpp | 61 +++++++++---------- 1 file changed, 30 insertions(+), 31 deletions(-) diff --git a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp index 3fad152492df25..d7883991f1bdbf 100644 --- a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp @@ -60,10 +60,10 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) { const auto& pattern_map = m.get_pattern_value_map(); - auto input = pattern_map.at(input_m); - auto input_ps = input.get_partial_shape(); + const auto& input = pattern_map.at(input_m); + const auto& input_ps = input.get_partial_shape(); - auto T = input.get_element_type(); + const auto& T = input.get_element_type(); // this pattern supports only real and not quantized data types if ((!T.is_real()) || (T.is_quantized())) @@ -77,8 +77,8 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { if (input_ps[1].is_dynamic()) return false; - auto pre_mvn_reshape_out = pattern_map.at(pre_mvn_reshape_m); - auto pre_mvn_reshape_out_ps = pre_mvn_reshape_out.get_partial_shape(); + const auto& pre_mvn_reshape_out = pattern_map.at(pre_mvn_reshape_m); + const auto& pre_mvn_reshape_out_ps = pre_mvn_reshape_out.get_partial_shape(); // expecting 3D static tensor as pre-MVN reshape input: // (batch_size, num_groups, -1) @@ -89,8 +89,8 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { if (pre_mvn_reshape_out_ps[1].is_dynamic()) return false; - auto num_channels = input_ps[1].get_max_length(); - auto num_groups = pre_mvn_reshape_out_ps[1].get_max_length(); + const auto& num_channels = input_ps[1].get_max_length(); + const auto& num_groups = 
pre_mvn_reshape_out_ps[1].get_max_length(); // number of channels has to be divisible by number of groups if (num_channels % num_groups != 0) @@ -107,16 +107,16 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { if (input_ps[0].get_max_length() != pre_mvn_reshape_out_ps[0].get_max_length()) return false; - auto post_instance_norm_reshape_out = pattern_map.at(post_instance_norm_reshape_m); - auto post_instance_norm_reshape_out_ps = post_instance_norm_reshape_out.get_partial_shape(); + const auto& post_instance_norm_reshape_out = pattern_map.at(post_instance_norm_reshape_m); + const auto& post_instance_norm_reshape_out_ps = post_instance_norm_reshape_out.get_partial_shape(); // post instance norm shape has to be same as in pattern input: // (batch_size, num_channels, height, width) if (post_instance_norm_reshape_out_ps != input_ps) return false; - auto group_norm_gamma = pattern_map.at(group_norm_gamma_m); - auto group_norm_gamma_ps = group_norm_gamma.get_partial_shape(); + const auto& group_norm_gamma = pattern_map.at(group_norm_gamma_m); + const auto& group_norm_gamma_ps = group_norm_gamma.get_partial_shape(); // group_norm_gamma has to share the same data type as // pattern input @@ -132,8 +132,8 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { if (ov::shape_size(group_norm_gamma.get_shape()) != num_channels) return false; - auto group_norm_beta = pattern_map.at(group_norm_beta_m); - auto group_norm_beta_ps = group_norm_beta.get_partial_shape(); + const auto& group_norm_beta = pattern_map.at(group_norm_beta_m); + const auto& group_norm_beta_ps = group_norm_beta.get_partial_shape(); // group_norm_beta has to share the same data type as // pattern input @@ -152,23 +152,23 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { auto expected_param_shape = ov::PartialShape({num_channels}); std::shared_ptr group_norm_gamma_1d_m = std::make_shared(group_norm_gamma); - auto group_norm_gamma_1d_out = group_norm_gamma_1d_m->get_default_output(); - auto group_norm_gamma_1d_out_ps = group_norm_gamma_1d_out.get_partial_shape(); + const auto& group_norm_gamma_1d_out = group_norm_gamma_1d_m->get_default_output(); + const auto& group_norm_gamma_1d_out_ps = group_norm_gamma_1d_out.get_partial_shape(); if (group_norm_gamma_1d_out_ps != expected_param_shape) return false; std::shared_ptr group_norm_beta_1d_m = std::make_shared(group_norm_beta); - auto group_norm_beta_1d_out = group_norm_beta_1d_m->get_default_output(); - auto group_norm_beta_1d_out_ps = group_norm_beta_1d_out.get_partial_shape(); + const auto& group_norm_beta_1d_out = group_norm_beta_1d_m->get_default_output(); + const auto& group_norm_beta_1d_out_ps = group_norm_beta_1d_out.get_partial_shape(); if (group_norm_beta_1d_out_ps != expected_param_shape) return false; std::shared_ptr instance_norm_beta_1d_m = nullptr; if (pattern_map.count(instance_norm_beta_m) > 0) { - auto instance_norm_beta = pattern_map.at(instance_norm_beta_m); - auto instance_norm_beta_ps = group_norm_beta.get_partial_shape(); + const auto& instance_norm_beta = pattern_map.at(instance_norm_beta_m); + const auto& instance_norm_beta_ps = group_norm_beta.get_partial_shape(); // instance_norm_beta has to share the same data type as // pattern input @@ -197,15 +197,15 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { for (auto i = 0; i < channels_to_groups_ratio; i++) instance_norm_beta_concat_inputs.push_back(instance_norm_beta_1d_m); instance_norm_beta_1d_m = std::make_shared(instance_norm_beta_concat_inputs, 
0); - auto instance_norm_beta_1d_out = instance_norm_beta_1d_m->get_default_output(); - auto instance_norm_beta_1d_ps = instance_norm_beta_1d_out.get_partial_shape(); + const auto& instance_norm_beta_1d_out = instance_norm_beta_1d_m->get_default_output(); + const auto& instance_norm_beta_1d_ps = instance_norm_beta_1d_out.get_partial_shape(); if (instance_norm_beta_1d_ps != expected_param_shape) return false; } if (pattern_map.count(instance_norm_gamma_m) > 0) { - auto instance_norm_gamma = pattern_map.at(instance_norm_gamma_m); - auto instance_norm_gamma_ps = group_norm_beta.get_partial_shape(); + const auto& instance_norm_gamma = pattern_map.at(instance_norm_gamma_m); + const auto& instance_norm_gamma_ps = group_norm_beta.get_partial_shape(); // instance_norm_gamma has to share the same data type as // pattern input @@ -235,8 +235,8 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { for (auto i = 0; i < channels_to_groups_ratio; i++) instance_norm_gamma_concat_inputs.push_back(instance_norm_gamma_1d_m); instance_norm_gamma_1d_m = std::make_shared(instance_norm_gamma_concat_inputs, 0); - auto instance_norm_gamma_1d_out = instance_norm_gamma_1d_m->get_default_output(); - auto instance_norm_gamma_1d_ps = instance_norm_gamma_1d_out.get_partial_shape(); + const auto& instance_norm_gamma_1d_out = instance_norm_gamma_1d_m->get_default_output(); + const auto& instance_norm_gamma_1d_ps = instance_norm_gamma_1d_out.get_partial_shape(); if (instance_norm_gamma_1d_ps != expected_param_shape) return false; @@ -263,11 +263,10 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { } } - // we need to be able to cast mvn to MVN layer type - // in order to read actual epsilon value - auto mvn_out = pattern_map.at(mvn_m); - auto mvn = std::dynamic_pointer_cast(mvn_out.get_node_shared_ptr()); - auto epsilon = mvn->get_eps(); + // we need to cast mvn to MVN layer type in order to read actual epsilon value + const auto& mvn_out = pattern_map.at(mvn_m); + const auto& mvn = std::dynamic_pointer_cast(mvn_out.get_node_shared_ptr()); + const auto& epsilon = mvn->get_eps(); // we can finally create GroupNormalization op std::shared_ptr group_norm = std::make_shared(input, @@ -285,4 +284,4 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { auto m = std::make_shared(group_norm_beta_add_m, matcher_name); this->register_matcher(m, callback); -} +} \ No newline at end of file From 07e4ccf46281658a4e6c03c71e5b1d7cb388c569 Mon Sep 17 00:00:00 2001 From: Jedrzej Hajduczenia Date: Thu, 16 Jan 2025 13:34:32 +0100 Subject: [PATCH 16/45] Move GroupNormalizationFusion after MVNFusion pass in GPU plugin transformations pipeline --- .../src/plugin/transformations_pipeline.cpp | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index ab70e737035c8d..510bd91cb71657 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -340,13 +340,6 @@ void TransformationsPipeline::apply(std::shared_ptr func) { auto pass_config = manager.get_pass_config(); manager.set_per_pass_validation(false); - // fuse following ops into GroupNormalization: - // group_norm_gamma * (instance_norm_gamma * MVN(x) + instance_norm_beta) + group_norm_beta - // note that instance norm related parameters are optional: - // - instance_norm_gamma is assumed to be filled with ones if 
not present in the graph
-    // - instance_norm_beta is assumed to be filled with zeros if not present in the graph
-    manager.register_pass<ov::pass::GroupNormalizationFusion>();
-
     // Temporary solution, global rt info cleanup is needed
     for (auto& node : func->get_ops()) {
         ov::enable_constant_folding(node);
@@ -415,6 +408,12 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
     // fuse softmax, MVN patterns, so that they will not be marked as precision sensitive in ConvertPrecision
     manager.register_pass<ov::pass::SoftmaxFusion>();
     manager.register_pass<ov::pass::MVNFusion>();
+    // fuse following ops into GroupNormalization:
+    // group_norm_gamma * (instance_norm_gamma * MVN(x) + instance_norm_beta) + group_norm_beta
+    // note that instance norm related parameters are optional:
+    // - instance_norm_gamma is assumed to be filled with ones if not present in the graph
+    // - instance_norm_beta is assumed to be filled with zeros if not present in the graph
+    manager.register_pass<ov::pass::GroupNormalizationFusion>();
     // decompose MVNs that are not supported in GPU, so that they will be marked as precision sensitive in ConvertPrecision
     manager.register_pass<ov::pass::MVN6Decomposition>();
     // Run these broadcast optimizations earlier to ensure that those are executed before NopElimination/ConstantFolding

From b8b3459898e14481e37f908882b270af11b6ac77 Mon Sep 17 00:00:00 2001
From: Jedrzej Hajduczenia
Date: Thu, 16 Jan 2025 13:39:41 +0100
Subject: [PATCH 17/45] Use OV ptr cast for MVN in GroupNormalizationFusion
 pass

---
 .../common_optimizations/group_normalization_fusion.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp
index d7883991f1bdbf..387b6c450e8d5b 100644
--- a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp
+++ b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp
@@ -265,7 +265,7 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {

         // we need to cast mvn to MVN layer type in order to read actual epsilon value
         const auto& mvn_out = pattern_map.at(mvn_m);
-        const auto& mvn = std::dynamic_pointer_cast<ov::op::v6::MVN>(mvn_out.get_node_shared_ptr());
+        const auto& mvn = ov::as_type_ptr<ov::op::v6::MVN>(mvn_out.get_node_shared_ptr());
         const auto& epsilon = mvn->get_eps();

         // we can finally create GroupNormalization op

From 6c0ea7156a7bda396bdcce10696b0fbc9bfb35e4 Mon Sep 17 00:00:00 2001
From: Jedrzej Hajduczenia
Date: Thu, 16 Jan 2025 15:23:41 +0100
Subject: [PATCH 18/45] Add 5d and 6d cases to GroupNormalizationFusion tests
 + fix formatting

---
 .../group_normalization_fusion_tests.cpp | 92 ++++++++++---------
 1 file changed, 48 insertions(+), 44 deletions(-)

diff --git a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp
index cdf0563b2395ad..fcfb1f42d60814 100644
--- a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp
+++ b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp
@@ -37,8 +37,7 @@
 template <element::Type_t T_act_elem, element::Type_t T_gn_gamma_elem, element::Type_t T_gn_beta_elem,
           element::Type_t T_in_gamma_elem, element::Type_t T_in_beta_elem>
-class GroupNormalizationFusionTestsFixture
-    : public TestWithParam<ValuesContainerWithPositiveTestFlag> {
+class GroupNormalizationFusionTestsFixture : public TestWithParam<ValuesContainerWithPositiveTestFlag> {
 public:
     static constexpr element::Type_t T_act_elem_t = T_act_elem;
     static constexpr element::Type_t T_gn_gamma_elem_t = T_gn_gamma_elem;
@@ -343,18 +342,23 @@ std::vector<RawValuesContainer> valid_vals = {
Shape{1, 320, 1, 1}, 1, 1e-5f), - std::make_tuple(PartialShape{1, 320, 2, 2}, + std::make_tuple(PartialShape{1, 320, 2, 2, 2}, Shape{1, 320, 1}, Shape{1, 320, 1}, - Shape{320, 1, 1}, - Shape{320, 1, 1}, + Shape{320, 1, 1, 1}, + Shape{320, 1, 1, 1}, 320, 1e-5f), - std::make_tuple(PartialShape{Dimension::dynamic(), 320, Dimension::dynamic(), Dimension::dynamic()}, + std::make_tuple(PartialShape{Dimension::dynamic(), + 320, + Dimension::dynamic(), + Dimension::dynamic(), + Dimension::dynamic(), + Dimension::dynamic()}, Shape{1, 320, 1}, Shape{1, 320, 1}, - Shape{320, 1, 1}, - Shape{320, 1, 1}, + Shape{320, 1, 1, 1, 1}, + Shape{320, 1, 1, 1, 1}, 320, 1e-5f), std::make_tuple(PartialShape{Dimension::dynamic(), 320}, @@ -394,42 +398,42 @@ std::vector valid_vals = { 64, 1e-6f)}; -auto invalid_vals = Values( - std::make_tuple(false, PartialShape{1, 320}, Shape{}, Shape{}, Shape{}, Shape{}, 1, 1e-5f), - std::make_tuple(false, - PartialShape{1, 320, 2, 2}, - Shape{1, 1, 1}, - Shape{1, 1, 1}, - Shape{1, 1, 1}, - Shape{1, 1, 1, 1}, - 1, - 1e-5f), - std::make_tuple(false, PartialShape{1, 320, 2, 2}, Shape{}, Shape{}, Shape{320, 1, 1}, Shape{}, 1, 1e-5f), - std::make_tuple(false, PartialShape{1, 320, 2, 2}, Shape{}, Shape{}, Shape{}, Shape{1, 320, 1, 1}, 1, 1e-5f), - std::make_tuple(false, - PartialShape{1, 320, 2, 2}, - Shape{1, 1, 1}, - Shape{1, 32, 1}, - Shape{320, 1, 1}, - Shape{320, 1, 1}, - 32, - 1e-5f), - std::make_tuple(false, - PartialShape{1, 320, 2, 2}, - Shape{1, 32, 1}, - Shape{1, 1, 1}, - Shape{320, 1, 1}, - Shape{320, 1, 1}, - 32, - 1e-5f), - std::make_tuple(false, - PartialShape{Dimension::dynamic(), 512, Dimension::dynamic(), Dimension::dynamic()}, - Shape{}, - Shape{}, - Shape{1, 512, 1, 1}, - Shape{1, 512, 1, 1}, - 100, - 1e-6f)); +auto invalid_vals = + Values(std::make_tuple(false, PartialShape{1, 320}, Shape{}, Shape{}, Shape{}, Shape{}, 1, 1e-5f), + std::make_tuple(false, + PartialShape{1, 320, 2, 2}, + Shape{1, 1, 1}, + Shape{1, 1, 1}, + Shape{1, 1, 1}, + Shape{1, 1, 1, 1}, + 1, + 1e-5f), + std::make_tuple(false, PartialShape{1, 320, 2, 2}, Shape{}, Shape{}, Shape{320, 1, 1}, Shape{}, 1, 1e-5f), + std::make_tuple(false, PartialShape{1, 320, 2, 2}, Shape{}, Shape{}, Shape{}, Shape{1, 320, 1, 1}, 1, 1e-5f), + std::make_tuple(false, + PartialShape{1, 320, 2, 2}, + Shape{1, 1, 1}, + Shape{1, 32, 1}, + Shape{320, 1, 1}, + Shape{320, 1, 1}, + 32, + 1e-5f), + std::make_tuple(false, + PartialShape{1, 320, 2, 2}, + Shape{1, 32, 1}, + Shape{1, 1, 1}, + Shape{320, 1, 1}, + Shape{320, 1, 1}, + 32, + 1e-5f), + std::make_tuple(false, + PartialShape{Dimension::dynamic(), 512, Dimension::dynamic(), Dimension::dynamic()}, + Shape{}, + Shape{}, + Shape{1, 512, 1, 1}, + Shape{1, 512, 1, 1}, + 100, + 1e-6f)); std::vector add_positive_test_flag_to_vals( const bool positive_test, From 8f2c63e89f44db1326ff911a98aee9aff49ded76 Mon Sep 17 00:00:00 2001 From: Jedrzej Hajduczenia Date: Thu, 16 Jan 2025 15:25:45 +0100 Subject: [PATCH 19/45] Use predicates for type & shape checks that don't depend on other nodes in GroupNormalizationFusion pass --- .../group_normalization_fusion.cpp | 93 ++++++++----------- 1 file changed, 39 insertions(+), 54 deletions(-) diff --git a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp index 387b6c450e8d5b..ced652fc5edb8c 100644 --- 
a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp @@ -25,35 +25,63 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { MATCHER_SCOPE(GroupNormalizationFusion); - auto input_m = ov::pass::pattern::any_input(); + auto has_real_not_quantized_type = [](const ov::Output& output) -> bool { + const auto& T = output.get_element_type(); + return (T.is_real() && (!T.is_quantized())); + }; + + auto has_integral_type = [](const ov::Output& output) -> bool { + const auto& T = output.get_element_type(); + return (T.is_integral()); + }; - auto pre_mvn_shape_const_m = ov::pass::pattern::wrap_type(); - auto pre_mvn_reshape_m = ov::pass::pattern::wrap_type({input_m, pre_mvn_shape_const_m}); + auto has_at_least_2d_shape = [](const ov::Output& output) -> bool { + const auto& output_ps = output.get_partial_shape(); + return (output_ps.rank().is_static()) && (output_ps.rank().get_length() >= 2); + }; - auto axes_const_m = ov::pass::pattern::wrap_type(); + auto input_m = ov::pass::pattern::any_input(ov::pass::pattern::all_of( + {has_real_not_quantized_type, has_at_least_2d_shape, ov::pass::pattern::has_static_dim(1)})); + + auto pre_mvn_shape_const_m = ov::pass::pattern::wrap_type(ov::pass::pattern::all_of( + {has_integral_type, ov::pass::pattern::rank_equals(1), ov::pass::pattern::has_static_dim(0)})); + auto pre_mvn_reshape_m = ov::pass::pattern::wrap_type( + {input_m, pre_mvn_shape_const_m}, + ov::pass::pattern::all_of( + {has_real_not_quantized_type, ov::pass::pattern::rank_equals(3), ov::pass::pattern::has_static_dim(1)})); + + auto axes_const_m = ov::pass::pattern::wrap_type(ov::pass::pattern::all_of( + {has_integral_type, ov::pass::pattern::rank_equals(1), ov::pass::pattern::has_static_dim(0)})); auto mvn_m = ov::pass::pattern::wrap_type({pre_mvn_reshape_m, axes_const_m}); - auto instance_norm_gamma_m = ov::pass::pattern::any_input(); + auto instance_norm_gamma_m = ov::pass::pattern::any_input( + ov::pass::pattern::all_of({has_real_not_quantized_type, ov::pass::pattern::has_static_shape()})); auto instance_norm_gamma_multiply_m = ov::pass::pattern::wrap_type({mvn_m, instance_norm_gamma_m}); auto instance_norm_opt_gamma_m = std::make_shared(ov::OutputVector{mvn_m, instance_norm_gamma_multiply_m}); - auto instance_norm_beta_m = ov::pass::pattern::any_input(); + auto instance_norm_beta_m = ov::pass::pattern::any_input( + ov::pass::pattern::all_of({has_real_not_quantized_type, ov::pass::pattern::has_static_shape()})); auto instance_norm_beta_add_m = ov::pass::pattern::wrap_type({instance_norm_opt_gamma_m, instance_norm_beta_m}); auto instance_norm_opt_gamma_opt_beta_m = std::make_shared( ov::OutputVector{instance_norm_opt_gamma_m, instance_norm_beta_add_m}); - auto post_instance_norm_shape_m = ov::pass::pattern::any_input(); + auto post_instance_norm_shape_m = ov::pass::pattern::any_input(ov::pass::pattern::all_of( + {has_integral_type, ov::pass::pattern::rank_equals(1), ov::pass::pattern::has_static_dim(0)})); auto post_instance_norm_reshape_m = ov::pass::pattern::wrap_type( - {instance_norm_opt_gamma_opt_beta_m, post_instance_norm_shape_m}); + {instance_norm_opt_gamma_opt_beta_m, post_instance_norm_shape_m}, + ov::pass::pattern::all_of( + {has_real_not_quantized_type, has_at_least_2d_shape, ov::pass::pattern::has_static_dim(1)})); - auto group_norm_gamma_m = ov::pass::pattern::any_input(); + auto group_norm_gamma_m = 
ov::pass::pattern::any_input( + ov::pass::pattern::all_of({has_real_not_quantized_type, ov::pass::pattern::has_static_shape()})); auto group_norm_gamma_multiply_m = ov::pass::pattern::wrap_type({post_instance_norm_reshape_m, group_norm_gamma_m}); - auto group_norm_beta_m = ov::pass::pattern::any_input(); + auto group_norm_beta_m = ov::pass::pattern::any_input( + ov::pass::pattern::all_of({has_real_not_quantized_type, ov::pass::pattern::has_static_shape()})); auto group_norm_beta_add_m = ov::pass::pattern::wrap_type({group_norm_gamma_multiply_m, group_norm_beta_m}); @@ -65,30 +93,9 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { const auto& T = input.get_element_type(); - // this pattern supports only real and not quantized data types - if ((!T.is_real()) || (T.is_quantized())) - return false; - - // expecting at least 2D tensor as pattern input: - // (batch_size, num_channels, ...) - if (input_ps.size() < 2) - return false; - // channel dimension has to be static, all other dimensions in input can be dynamic - if (input_ps[1].is_dynamic()) - return false; - const auto& pre_mvn_reshape_out = pattern_map.at(pre_mvn_reshape_m); const auto& pre_mvn_reshape_out_ps = pre_mvn_reshape_out.get_partial_shape(); - // expecting 3D static tensor as pre-MVN reshape input: - // (batch_size, num_groups, -1) - if (pre_mvn_reshape_out_ps.size() != 3) - return false; - - // channel dimension has to be static, all other dimensions in pre-MVN reshape can be dynamic - if (pre_mvn_reshape_out_ps[1].is_dynamic()) - return false; - const auto& num_channels = input_ps[1].get_max_length(); const auto& num_groups = pre_mvn_reshape_out_ps[1].get_max_length(); @@ -97,11 +104,6 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { return false; auto channels_to_groups_ratio = num_channels / num_groups; - // MVN input has to have at least two dimensions: - // (batch_size, num_groups, ...) 
- if (pre_mvn_reshape_out_ps.size() < 2) - return false; - // first dimension of MVN input (batch_size) has to be the same // as in pattern input if (input_ps[0].get_max_length() != pre_mvn_reshape_out_ps[0].get_max_length()) @@ -110,8 +112,7 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { const auto& post_instance_norm_reshape_out = pattern_map.at(post_instance_norm_reshape_m); const auto& post_instance_norm_reshape_out_ps = post_instance_norm_reshape_out.get_partial_shape(); - // post instance norm shape has to be same as in pattern input: - // (batch_size, num_channels, height, width) + // post instance norm shape has to be same as in pattern input if (post_instance_norm_reshape_out_ps != input_ps) return false; @@ -123,10 +124,6 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { if (group_norm_gamma.get_element_type() != T) return false; - // group_norm_gamma has to be static - if (group_norm_gamma_ps.is_dynamic()) - return false; - // number of elements in group_norm_gamma must be equal to // number of channels if (ov::shape_size(group_norm_gamma.get_shape()) != num_channels) @@ -140,10 +137,6 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { if (group_norm_beta.get_element_type() != T) return false; - // group_norm_beta has to be static - if (group_norm_beta_ps.is_dynamic()) - return false; - // number of elements in group_norm_beta must be equal to // number of channels if (ov::shape_size(group_norm_beta.get_shape()) != num_channels) @@ -175,10 +168,6 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { if (instance_norm_beta.get_element_type() != T) return false; - // instance_norm_beta has to be static - if (instance_norm_beta_ps.is_dynamic()) - return false; - // number of elements in instance_norm_beta must be equal to // number of groups if (ov::shape_size(instance_norm_beta.get_shape()) != num_groups) @@ -212,10 +201,6 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { if (instance_norm_gamma.get_element_type() != T) return false; - // instance_norm_gamma has to be static - if (instance_norm_gamma_ps.is_dynamic()) - return false; - // number of elements in instance_norm_gamma must be equal to // number of groups if (ov::shape_size(instance_norm_gamma.get_shape()) != num_groups) From dafeab47a373ea3571b35f05b70a0153214d4ad2 Mon Sep 17 00:00:00 2001 From: Jedrzej Hajduczenia Date: Fri, 17 Jan 2025 11:40:39 +0100 Subject: [PATCH 20/45] Use ov::pass::pattern namespace in GroupNormalizationFusion pass --- .../group_normalization_fusion.cpp | 58 ++++++++----------- 1 file changed, 24 insertions(+), 34 deletions(-) diff --git a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp index ced652fc5edb8c..ca63161081b307 100644 --- a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp @@ -22,6 +22,8 @@ #include "openvino/pass/pattern/op/wrap_type.hpp" #include "transformations/utils/utils.hpp" +using namespace ov::pass::pattern; + ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { MATCHER_SCOPE(GroupNormalizationFusion); @@ -40,52 +42,40 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { return (output_ps.rank().is_static()) && (output_ps.rank().get_length() >= 2); }; - auto 
input_m = ov::pass::pattern::any_input(ov::pass::pattern::all_of( - {has_real_not_quantized_type, has_at_least_2d_shape, ov::pass::pattern::has_static_dim(1)})); + auto input_m = any_input(all_of({has_real_not_quantized_type, has_at_least_2d_shape, has_static_dim(1)})); - auto pre_mvn_shape_const_m = ov::pass::pattern::wrap_type(ov::pass::pattern::all_of( - {has_integral_type, ov::pass::pattern::rank_equals(1), ov::pass::pattern::has_static_dim(0)})); - auto pre_mvn_reshape_m = ov::pass::pattern::wrap_type( - {input_m, pre_mvn_shape_const_m}, - ov::pass::pattern::all_of( - {has_real_not_quantized_type, ov::pass::pattern::rank_equals(3), ov::pass::pattern::has_static_dim(1)})); + auto pre_mvn_shape_const_m = + wrap_type(all_of({has_integral_type, rank_equals(1), has_static_dim(0)})); + auto pre_mvn_reshape_m = + wrap_type({input_m, pre_mvn_shape_const_m}, + all_of({has_real_not_quantized_type, rank_equals(3), has_static_dim(1)})); - auto axes_const_m = ov::pass::pattern::wrap_type(ov::pass::pattern::all_of( - {has_integral_type, ov::pass::pattern::rank_equals(1), ov::pass::pattern::has_static_dim(0)})); - auto mvn_m = ov::pass::pattern::wrap_type({pre_mvn_reshape_m, axes_const_m}); + auto axes_const_m = wrap_type(all_of({has_integral_type, rank_equals(1), has_static_dim(0)})); + auto mvn_m = wrap_type({pre_mvn_reshape_m, axes_const_m}); - auto instance_norm_gamma_m = ov::pass::pattern::any_input( - ov::pass::pattern::all_of({has_real_not_quantized_type, ov::pass::pattern::has_static_shape()})); - auto instance_norm_gamma_multiply_m = - ov::pass::pattern::wrap_type({mvn_m, instance_norm_gamma_m}); + auto instance_norm_gamma_m = any_input(all_of({has_real_not_quantized_type, has_static_shape()})); + auto instance_norm_gamma_multiply_m = wrap_type({mvn_m, instance_norm_gamma_m}); auto instance_norm_opt_gamma_m = std::make_shared(ov::OutputVector{mvn_m, instance_norm_gamma_multiply_m}); - auto instance_norm_beta_m = ov::pass::pattern::any_input( - ov::pass::pattern::all_of({has_real_not_quantized_type, ov::pass::pattern::has_static_shape()})); - auto instance_norm_beta_add_m = - ov::pass::pattern::wrap_type({instance_norm_opt_gamma_m, instance_norm_beta_m}); + auto instance_norm_beta_m = any_input(all_of({has_real_not_quantized_type, has_static_shape()})); + auto instance_norm_beta_add_m = wrap_type({instance_norm_opt_gamma_m, instance_norm_beta_m}); auto instance_norm_opt_gamma_opt_beta_m = std::make_shared( ov::OutputVector{instance_norm_opt_gamma_m, instance_norm_beta_add_m}); - auto post_instance_norm_shape_m = ov::pass::pattern::any_input(ov::pass::pattern::all_of( - {has_integral_type, ov::pass::pattern::rank_equals(1), ov::pass::pattern::has_static_dim(0)})); - auto post_instance_norm_reshape_m = ov::pass::pattern::wrap_type( - {instance_norm_opt_gamma_opt_beta_m, post_instance_norm_shape_m}, - ov::pass::pattern::all_of( - {has_real_not_quantized_type, has_at_least_2d_shape, ov::pass::pattern::has_static_dim(1)})); + auto post_instance_norm_shape_m = any_input(all_of({has_integral_type, rank_equals(1), has_static_dim(0)})); + auto post_instance_norm_reshape_m = + wrap_type({instance_norm_opt_gamma_opt_beta_m, post_instance_norm_shape_m}, + all_of({has_real_not_quantized_type, has_at_least_2d_shape, has_static_dim(1)})); - auto group_norm_gamma_m = ov::pass::pattern::any_input( - ov::pass::pattern::all_of({has_real_not_quantized_type, ov::pass::pattern::has_static_shape()})); + auto group_norm_gamma_m = any_input(all_of({has_real_not_quantized_type, has_static_shape()})); auto 
group_norm_gamma_multiply_m = - ov::pass::pattern::wrap_type({post_instance_norm_reshape_m, group_norm_gamma_m}); + wrap_type({post_instance_norm_reshape_m, group_norm_gamma_m}); - auto group_norm_beta_m = ov::pass::pattern::any_input( - ov::pass::pattern::all_of({has_real_not_quantized_type, ov::pass::pattern::has_static_shape()})); - auto group_norm_beta_add_m = - ov::pass::pattern::wrap_type({group_norm_gamma_multiply_m, group_norm_beta_m}); + auto group_norm_beta_m = any_input(all_of({has_real_not_quantized_type, has_static_shape()})); + auto group_norm_beta_add_m = wrap_type({group_norm_gamma_multiply_m, group_norm_beta_m}); - ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) { + ov::matcher_pass_callback callback = [=](Matcher& m) { const auto& pattern_map = m.get_pattern_value_map(); const auto& input = pattern_map.at(input_m); @@ -267,6 +257,6 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { return true; }; - auto m = std::make_shared(group_norm_beta_add_m, matcher_name); + auto m = std::make_shared(group_norm_beta_add_m, matcher_name); this->register_matcher(m, callback); } \ No newline at end of file From ad7274c153777d55ae679874f0b0aa86e56865ea Mon Sep 17 00:00:00 2001 From: Jedrzej Hajduczenia Date: Fri, 17 Jan 2025 11:42:13 +0100 Subject: [PATCH 21/45] Remove redundant has_integral_type predicate from GroupNormalizationFusion pass --- .../group_normalization_fusion.cpp | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp index ca63161081b307..a1c872f26a7ab5 100644 --- a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp @@ -32,11 +32,6 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { return (T.is_real() && (!T.is_quantized())); }; - auto has_integral_type = [](const ov::Output& output) -> bool { - const auto& T = output.get_element_type(); - return (T.is_integral()); - }; - auto has_at_least_2d_shape = [](const ov::Output& output) -> bool { const auto& output_ps = output.get_partial_shape(); return (output_ps.rank().is_static()) && (output_ps.rank().get_length() >= 2); @@ -44,13 +39,12 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { auto input_m = any_input(all_of({has_real_not_quantized_type, has_at_least_2d_shape, has_static_dim(1)})); - auto pre_mvn_shape_const_m = - wrap_type(all_of({has_integral_type, rank_equals(1), has_static_dim(0)})); + auto pre_mvn_shape_const_m = wrap_type(all_of({rank_equals(1), has_static_dim(0)})); auto pre_mvn_reshape_m = wrap_type({input_m, pre_mvn_shape_const_m}, all_of({has_real_not_quantized_type, rank_equals(3), has_static_dim(1)})); - auto axes_const_m = wrap_type(all_of({has_integral_type, rank_equals(1), has_static_dim(0)})); + auto axes_const_m = wrap_type(all_of({rank_equals(1), has_static_dim(0)})); auto mvn_m = wrap_type({pre_mvn_reshape_m, axes_const_m}); auto instance_norm_gamma_m = any_input(all_of({has_real_not_quantized_type, has_static_shape()})); @@ -63,7 +57,7 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { auto instance_norm_opt_gamma_opt_beta_m = std::make_shared( ov::OutputVector{instance_norm_opt_gamma_m, instance_norm_beta_add_m}); - auto 
post_instance_norm_shape_m = any_input(all_of({has_integral_type, rank_equals(1), has_static_dim(0)})); + auto post_instance_norm_shape_m = any_input(all_of({rank_equals(1), has_static_dim(0)})); auto post_instance_norm_reshape_m = wrap_type({instance_norm_opt_gamma_opt_beta_m, post_instance_norm_shape_m}, all_of({has_real_not_quantized_type, has_at_least_2d_shape, has_static_dim(1)})); From fa84fd3741b28dcadb9d5be2f81a55624fa6ee2c Mon Sep 17 00:00:00 2001 From: Jedrzej Hajduczenia Date: Fri, 17 Jan 2025 12:01:44 +0100 Subject: [PATCH 22/45] Simplify accessing nodes partial shapes in GroupNormalizationFusion pass --- .../group_normalization_fusion.cpp | 23 ++++++------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp index a1c872f26a7ab5..73b2dc65e4a9d2 100644 --- a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp @@ -77,8 +77,7 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { const auto& T = input.get_element_type(); - const auto& pre_mvn_reshape_out = pattern_map.at(pre_mvn_reshape_m); - const auto& pre_mvn_reshape_out_ps = pre_mvn_reshape_out.get_partial_shape(); + const auto& pre_mvn_reshape_out_ps = pattern_map.at(pre_mvn_reshape_m).get_partial_shape(); const auto& num_channels = input_ps[1].get_max_length(); const auto& num_groups = pre_mvn_reshape_out_ps[1].get_max_length(); @@ -93,15 +92,14 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { if (input_ps[0].get_max_length() != pre_mvn_reshape_out_ps[0].get_max_length()) return false; - const auto& post_instance_norm_reshape_out = pattern_map.at(post_instance_norm_reshape_m); - const auto& post_instance_norm_reshape_out_ps = post_instance_norm_reshape_out.get_partial_shape(); + const auto& post_instance_norm_reshape_out_ps = + pattern_map.at(post_instance_norm_reshape_m).get_partial_shape(); // post instance norm shape has to be same as in pattern input if (post_instance_norm_reshape_out_ps != input_ps) return false; const auto& group_norm_gamma = pattern_map.at(group_norm_gamma_m); - const auto& group_norm_gamma_ps = group_norm_gamma.get_partial_shape(); // group_norm_gamma has to share the same data type as // pattern input @@ -114,7 +112,6 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { return false; const auto& group_norm_beta = pattern_map.at(group_norm_beta_m); - const auto& group_norm_beta_ps = group_norm_beta.get_partial_shape(); // group_norm_beta has to share the same data type as // pattern input @@ -129,15 +126,13 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { auto expected_param_shape = ov::PartialShape({num_channels}); std::shared_ptr group_norm_gamma_1d_m = std::make_shared(group_norm_gamma); - const auto& group_norm_gamma_1d_out = group_norm_gamma_1d_m->get_default_output(); - const auto& group_norm_gamma_1d_out_ps = group_norm_gamma_1d_out.get_partial_shape(); + const auto& group_norm_gamma_1d_out_ps = group_norm_gamma_1d_m->get_output_partial_shape(0); if (group_norm_gamma_1d_out_ps != expected_param_shape) return false; std::shared_ptr group_norm_beta_1d_m = std::make_shared(group_norm_beta); - const auto& group_norm_beta_1d_out = 
group_norm_beta_1d_m->get_default_output(); - const auto& group_norm_beta_1d_out_ps = group_norm_beta_1d_out.get_partial_shape(); + const auto& group_norm_beta_1d_out_ps = group_norm_beta_1d_m->get_output_partial_shape(0); if (group_norm_beta_1d_out_ps != expected_param_shape) return false; @@ -145,7 +140,6 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { std::shared_ptr instance_norm_beta_1d_m = nullptr; if (pattern_map.count(instance_norm_beta_m) > 0) { const auto& instance_norm_beta = pattern_map.at(instance_norm_beta_m); - const auto& instance_norm_beta_ps = group_norm_beta.get_partial_shape(); // instance_norm_beta has to share the same data type as // pattern input @@ -170,15 +164,13 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { for (auto i = 0; i < channels_to_groups_ratio; i++) instance_norm_beta_concat_inputs.push_back(instance_norm_beta_1d_m); instance_norm_beta_1d_m = std::make_shared(instance_norm_beta_concat_inputs, 0); - const auto& instance_norm_beta_1d_out = instance_norm_beta_1d_m->get_default_output(); - const auto& instance_norm_beta_1d_ps = instance_norm_beta_1d_out.get_partial_shape(); + const auto& instance_norm_beta_1d_ps = instance_norm_beta_1d_m->get_output_partial_shape(0); if (instance_norm_beta_1d_ps != expected_param_shape) return false; } if (pattern_map.count(instance_norm_gamma_m) > 0) { const auto& instance_norm_gamma = pattern_map.at(instance_norm_gamma_m); - const auto& instance_norm_gamma_ps = group_norm_beta.get_partial_shape(); // instance_norm_gamma has to share the same data type as // pattern input @@ -204,8 +196,7 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { for (auto i = 0; i < channels_to_groups_ratio; i++) instance_norm_gamma_concat_inputs.push_back(instance_norm_gamma_1d_m); instance_norm_gamma_1d_m = std::make_shared(instance_norm_gamma_concat_inputs, 0); - const auto& instance_norm_gamma_1d_out = instance_norm_gamma_1d_m->get_default_output(); - const auto& instance_norm_gamma_1d_ps = instance_norm_gamma_1d_out.get_partial_shape(); + const auto& instance_norm_gamma_1d_ps = instance_norm_gamma_1d_m->get_output_partial_shape(0); if (instance_norm_gamma_1d_ps != expected_param_shape) return false; From ca2067da1ed7a83b7dbef76327557d9598783399 Mon Sep 17 00:00:00 2001 From: Jedrzej Hajduczenia Date: Mon, 20 Jan 2025 15:32:16 +0100 Subject: [PATCH 23/45] Fix typo in one of types in GroupNormalizationFusion tests --- .../group_normalization_fusion_tests.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp index fcfb1f42d60814..cc2474353f64d9 100644 --- a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp +++ b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp @@ -113,9 +113,9 @@ class GroupNormalizationFusionTestsFixture : public TestWithParam(shape_size(group_norm_gamma_shape)); - auto group_norm_beta_vals = std::vector(); + auto group_norm_beta_vals = std::vector(); if (group_norm_beta_present) - group_norm_beta_vals = test::utils::generateVector(shape_size(group_norm_beta_shape)); + group_norm_beta_vals = test::utils::generateVector(shape_size(group_norm_beta_shape)); std::shared_ptr model(nullptr), model_ref(nullptr); { @@ -564,4 +564,4 @@ 
INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_f8e8m0, INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_f8e8m0, GroupNormalizationFusionTestsFixture_f8e8m0, - invalid_vals); \ No newline at end of file + invalid_vals); From 770f4f94c9d69fd2116165174a742aebf56ad88e Mon Sep 17 00:00:00 2001 From: Jedrzej Hajduczenia Date: Tue, 4 Feb 2025 16:10:02 +0100 Subject: [PATCH 24/45] Remove unused include files from GroupNormalizationFusion pass --- .../common_optimizations/group_normalization_fusion.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp index 73b2dc65e4a9d2..a0120434d7dc6d 100644 --- a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp @@ -9,13 +9,11 @@ #include "openvino/op/add.hpp" #include "openvino/op/concat.hpp" #include "openvino/op/constant.hpp" -#include "openvino/op/convert.hpp" #include "openvino/op/divide.hpp" #include "openvino/op/group_normalization.hpp" #include "openvino/op/multiply.hpp" #include "openvino/op/mvn.hpp" #include "openvino/op/reshape.hpp" -#include "openvino/op/shape_of.hpp" #include "openvino/op/squeeze.hpp" #include "openvino/op/subtract.hpp" #include "openvino/pass/pattern/op/or.hpp" @@ -244,4 +242,4 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { auto m = std::make_shared(group_norm_beta_add_m, matcher_name); this->register_matcher(m, callback); -} \ No newline at end of file +} From c8cd6aa120f5e7efc25249a2a7b6c0ec44caaf2b Mon Sep 17 00:00:00 2001 From: Jedrzej Hajduczenia Date: Tue, 4 Feb 2025 16:15:26 +0100 Subject: [PATCH 25/45] Fix handling instance norm gamma & beta in GroupNormalizationFusion pass --- .../group_normalization_fusion.cpp | 56 +++++++++---------- 1 file changed, 25 insertions(+), 31 deletions(-) diff --git a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp index a0120434d7dc6d..5e5dcd8ceebd6f 100644 --- a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp @@ -7,15 +7,13 @@ #include "itt.hpp" #include "openvino/core/rt_info.hpp" #include "openvino/op/add.hpp" -#include "openvino/op/concat.hpp" #include "openvino/op/constant.hpp" -#include "openvino/op/divide.hpp" +#include "openvino/op/gather.hpp" #include "openvino/op/group_normalization.hpp" #include "openvino/op/multiply.hpp" #include "openvino/op/mvn.hpp" #include "openvino/op/reshape.hpp" #include "openvino/op/squeeze.hpp" -#include "openvino/op/subtract.hpp" #include "openvino/pass/pattern/op/or.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" #include "transformations/utils/utils.hpp" @@ -135,6 +133,13 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { if (group_norm_beta_1d_out_ps != expected_param_shape) return false; + auto gather_axis_const_m = op::v0::Constant::create(element::i64, Shape{1}, {0}); + auto gather_indices_vals = std::vector(); + for (auto i = 0; i < num_groups; i++) + 
gather_indices_vals.insert(gather_indices_vals.end(), channels_to_groups_ratio, i); + auto gather_indices_const_m = + op::v0::Constant::create(element::i64, Shape{static_cast(num_channels)}, gather_indices_vals); + std::shared_ptr instance_norm_beta_1d_m = nullptr; if (pattern_map.count(instance_norm_beta_m) > 0) { const auto& instance_norm_beta = pattern_map.at(instance_norm_beta_m); @@ -158,13 +163,20 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { } else { instance_norm_beta_1d_m = std::make_shared(instance_norm_beta); } - ov::OutputVector instance_norm_beta_concat_inputs; - for (auto i = 0; i < channels_to_groups_ratio; i++) - instance_norm_beta_concat_inputs.push_back(instance_norm_beta_1d_m); - instance_norm_beta_1d_m = std::make_shared(instance_norm_beta_concat_inputs, 0); + + instance_norm_beta_1d_m = std::make_shared(instance_norm_beta_1d_m, + gather_indices_const_m, + gather_axis_const_m); + const auto& instance_norm_beta_1d_ps = instance_norm_beta_1d_m->get_output_partial_shape(0); if (instance_norm_beta_1d_ps != expected_param_shape) return false; + + // group_norm_beta = group_norm_gamma * instance_norm_beta + group_norm_beta + auto group_norm_beta_corr_multiply_m = + std::make_shared(group_norm_gamma_1d_m, instance_norm_beta_1d_m); + group_norm_beta_1d_m = + std::make_shared(group_norm_beta_corr_multiply_m, group_norm_beta_1d_m); } if (pattern_map.count(instance_norm_gamma_m) > 0) { @@ -190,35 +202,17 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { } else { instance_norm_gamma_1d_m = std::make_shared(instance_norm_gamma); } - ov::OutputVector instance_norm_gamma_concat_inputs; - for (auto i = 0; i < channels_to_groups_ratio; i++) - instance_norm_gamma_concat_inputs.push_back(instance_norm_gamma_1d_m); - instance_norm_gamma_1d_m = std::make_shared(instance_norm_gamma_concat_inputs, 0); + + instance_norm_gamma_1d_m = std::make_shared(instance_norm_gamma_1d_m, + gather_indices_const_m, + gather_axis_const_m); const auto& instance_norm_gamma_1d_ps = instance_norm_gamma_1d_m->get_output_partial_shape(0); if (instance_norm_gamma_1d_ps != expected_param_shape) return false; - // group_norm_gamma /= instance_norm_gamma + // group_norm_gamma *= instance_norm_gamma group_norm_gamma_1d_m = - std::make_shared(group_norm_gamma_1d_m, instance_norm_gamma_1d_m); - - if (pattern_map.count(instance_norm_beta_m) > 0) { - // group_norm_beta -= group_norm_gamma * instance_norm_beta / instance_norm_gamma - auto group_norm_beta_corr_multiply_m = - std::make_shared(group_norm_gamma_1d_m, instance_norm_beta_1d_m); - auto group_norm_beta_corr_divide_m = - std::make_shared(group_norm_beta_corr_multiply_m, instance_norm_gamma_1d_m); - group_norm_beta_1d_m = - std::make_shared(group_norm_beta_1d_m, group_norm_beta_corr_divide_m); - } - } else { - if (pattern_map.count(instance_norm_beta_m) > 0) { - // group_norm_beta -= group_norm_gamma * instance_norm_beta - auto group_norm_beta_corr_multiply_m = - std::make_shared(group_norm_gamma_1d_m, instance_norm_beta_1d_m); - group_norm_beta_1d_m = - std::make_shared(group_norm_beta_1d_m, group_norm_beta_corr_multiply_m); - } + std::make_shared(group_norm_gamma_1d_m, instance_norm_gamma_1d_m); } // we need to cast mvn to MVN layer type in order to read actual epsilon value From 657b26b11e13e92fa578805f5e5a4e79ec1e5d44 Mon Sep 17 00:00:00 2001 From: Jedrzej Hajduczenia Date: Tue, 4 Feb 2025 16:20:02 +0100 Subject: [PATCH 26/45] Validate pre-MVN shape and MVN reduction axes in GroupNormalizationFusion pass --- 
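Note: the checks added in this patch only accept pre-MVN reshape target constants of the form (batch | 0, num_groups | 0, -1): batch and group dimensions are kept, and all trailing spatial dimensions are merged into one. A minimal standalone sketch of that acceptance rule, using a hypothetical helper name (illustrative only, not part of the patch):

#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors pre_mvn_shape_vals_correct(): dim 0 must be 0 ("copy from input")
// or the static batch size, dim 1 must be 0 or num_groups, and dim 2 must be
// -1 so that all spatial dimensions get merged into a single one.
static bool pre_mvn_shape_ok(const std::vector<int64_t>& vals,
                             bool batch_is_dynamic,
                             int64_t batch,
                             int64_t num_groups) {
    if (vals.size() != 3)
        return false;
    if (batch_is_dynamic) {
        if (vals[0] != 0)
            return false;
    } else if ((vals[0] != 0) && (vals[0] != batch)) {
        return false;
    }
    if ((vals[1] != 0) && (vals[1] != num_groups))
        return false;
    return vals[2] == -1;
}

int main() {
    assert(pre_mvn_shape_ok({0, 32, -1}, true, -1, 32));   // dynamic batch, explicit group count
    assert(pre_mvn_shape_ok({2, 32, -1}, false, 2, 32));   // static batch spelled out
    assert(!pre_mvn_shape_ok({2, 16, -1}, false, 2, 32));  // wrong number of groups
    assert(!pre_mvn_shape_ok({2, 32, 64}, false, 2, 32));  // spatial dimensions not merged
    return 0;
}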
 .../group_normalization_fusion.cpp | 98 ++++++++++++++++++-
 1 file changed, 94 insertions(+), 4 deletions(-)

diff --git a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp
index 5e5dcd8ceebd6f..9648c5ae9e9e13 100644
--- a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp
+++ b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp
@@ -20,6 +20,35 @@
 
 using namespace ov::pass::pattern;
 
+template <typename T, std::enable_if_t<std::is_integral<T>::value, bool> = true>
+bool pre_mvn_shape_vals_correct(const std::shared_ptr<ov::op::v0::Constant>& pre_mvn_shape_const,
+                                const ov::PartialShape& input_ps,
+                                const ov::Dimension::value_type num_groups) {
+    bool res = true;
+    std::vector<T> pre_mvn_shape_vals = pre_mvn_shape_const->get_vector<T>();
+    if (input_ps[0].is_dynamic()) {
+        if (pre_mvn_shape_vals[0] != 0)
+            res = false;
+    } else {
+        if ((pre_mvn_shape_vals[0] != 0) && (pre_mvn_shape_vals[0] != input_ps[0].get_max_length()))
+            res = false;
+    }
+    if ((pre_mvn_shape_vals[1] != 0) && (pre_mvn_shape_vals[1] != num_groups))
+        res = false;
+    if (pre_mvn_shape_vals[2] != -1)
+        res = false;
+    return res;
+}
+
+template <typename T, std::enable_if_t<std::is_integral<T>::value, bool> = true>
+bool mvn_reduction_axes_correct(const std::shared_ptr<ov::op::v0::Constant>& mvn_reduction_axes_const) {
+    bool res = true;
+    std::vector<T> mvn_reduce_axes = mvn_reduction_axes_const->get_vector<T>();
+    if ((mvn_reduce_axes[0] != 2) && (mvn_reduce_axes[0] != -1))
+        return false;
+    return res;
+}
+
 ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
     MATCHER_SCOPE(GroupNormalizationFusion);
 
@@ -40,8 +69,8 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
         wrap_type<ov::op::v1::Reshape>({input_m, pre_mvn_shape_const_m},
                                        all_of({has_real_not_quantized_type, rank_equals(3), has_static_dim(1)}));
 
-    auto axes_const_m = wrap_type<ov::op::v0::Constant>(all_of({rank_equals(1), has_static_dim(0)}));
-    auto mvn_m = wrap_type<ov::op::v6::MVN>({pre_mvn_reshape_m, axes_const_m});
+    auto mvn_reduction_axes_const_m = wrap_type<ov::op::v0::Constant>(all_of({rank_equals(1), has_static_dim(0)}));
+    auto mvn_m = wrap_type<ov::op::v6::MVN>({pre_mvn_reshape_m, mvn_reduction_axes_const_m});
 
     auto instance_norm_gamma_m = any_input(all_of({has_real_not_quantized_type, has_static_shape()}));
     auto instance_norm_gamma_multiply_m = wrap_type<ov::op::v1::Multiply>({mvn_m, instance_norm_gamma_m});
@@ -78,6 +107,51 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
         const auto& num_channels = input_ps[1].get_max_length();
         const auto& num_groups = pre_mvn_reshape_out_ps[1].get_max_length();
 
+        // we expect to reshape input in a way that would merge all spatial dimensions
+        // but leave batch and channel dimensions untouched
+        const auto& pre_mvn_shape = pattern_map.at(pre_mvn_shape_const_m);
+        const auto& pre_mvn_shape_const =
+            ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(pre_mvn_shape_const_m).get_node_shared_ptr());
+        const auto& pre_mvn_shape_out_ps = pre_mvn_shape.get_shape();
+        if (pre_mvn_shape_out_ps[0] != 3)
+            return false;
+        switch (pre_mvn_shape_const->get_element_type()) {
+        case ov::element::i8:
+            if (!pre_mvn_shape_vals_correct<int8_t>(pre_mvn_shape_const, input_ps, num_groups))
+                return false;
+            break;
+        case ov::element::i16:
+            if (!pre_mvn_shape_vals_correct<int16_t>(pre_mvn_shape_const, input_ps, num_groups))
+                return false;
+            break;
+        case ov::element::i32:
+            if (!pre_mvn_shape_vals_correct<int32_t>(pre_mvn_shape_const, input_ps, num_groups))
+                return false;
+            break;
+        case ov::element::i64:
+            if (!pre_mvn_shape_vals_correct<int64_t>(pre_mvn_shape_const, input_ps, num_groups))
+                return false;
+            break;
+        case ov::element::u8:
+            if (!pre_mvn_shape_vals_correct<uint8_t>(pre_mvn_shape_const, input_ps, num_groups))
+                return false;
+            break;
+        case ov::element::u16:
+            if (!pre_mvn_shape_vals_correct<uint16_t>(pre_mvn_shape_const, input_ps, num_groups))
+                return false;
+            break;
+        case ov::element::u32:
+            if (!pre_mvn_shape_vals_correct<uint32_t>(pre_mvn_shape_const, input_ps, num_groups))
+                return false;
+            break;
+        case ov::element::u64:
+            if (!pre_mvn_shape_vals_correct<uint64_t>(pre_mvn_shape_const, input_ps, num_groups))
+                return false;
+            break;
+        default:
+            return false;
+        }
+
         // number of channels has to be divisible by number of groups
         if (num_channels % num_groups != 0)
             return false;
@@ -88,15 +162,31 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
         if (input_ps[0].get_max_length() != pre_mvn_reshape_out_ps[0].get_max_length())
             return false;
 
+        // we expect to execute normalization over last dimension of MVN input
+        const auto& mvn_reduction_axes = pattern_map.at(mvn_reduction_axes_const_m);
+        const auto& mvn_reduction_axes_const =
+            ov::as_type_ptr<ov::op::v0::Constant>(mvn_reduction_axes.get_node_shared_ptr());
+        const auto& mvn_reduction_axes_out_shape = mvn_reduction_axes.get_shape();
+        if (mvn_reduction_axes_out_shape[0] != 1)
+            return false;
+        switch (mvn_reduction_axes_const->get_element_type()) {
+        case ov::element::i32:
+            mvn_reduction_axes_correct<int32_t>(mvn_reduction_axes_const);
+            break;
+        case ov::element::i64:
+            mvn_reduction_axes_correct<int64_t>(mvn_reduction_axes_const);
+            break;
+        default:
+            break;
+        }
+
         const auto& post_instance_norm_reshape_out_ps =
             pattern_map.at(post_instance_norm_reshape_m).get_partial_shape();
-        // post instance norm shape has to be same as in pattern input
         if (post_instance_norm_reshape_out_ps != input_ps)
             return false;
 
         const auto& group_norm_gamma = pattern_map.at(group_norm_gamma_m);
-
         // group_norm_gamma has to share the same data type as
         // pattern input
         if (group_norm_gamma.get_element_type() != T)
             return false;

From fd6523e5251dc3e64dc91b629c5eb81f2adb0680 Mon Sep 17 00:00:00 2001
From: Jedrzej Hajduczenia
Date: Tue, 4 Feb 2025 16:26:42 +0100
Subject: [PATCH 27/45] Make instance norm gamma & beta explicitly optional in GroupNormalizationFusion pass

---
 .../group_normalization_fusion.cpp | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp
index 9648c5ae9e9e13..d3bcd456e1281e 100644
--- a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp
+++ b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp
@@ -14,7 +14,7 @@
 #include "openvino/op/mvn.hpp"
 #include "openvino/op/reshape.hpp"
 #include "openvino/op/squeeze.hpp"
-#include "openvino/pass/pattern/op/or.hpp"
+#include "openvino/pass/pattern/op/optional.hpp"
 #include "openvino/pass/pattern/op/wrap_type.hpp"
 #include "transformations/utils/utils.hpp"
 
@@ -73,14 +73,11 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
     auto mvn_m = wrap_type<ov::op::v6::MVN>({pre_mvn_reshape_m, mvn_reduction_axes_const_m});
 
     auto instance_norm_gamma_m = any_input(all_of({has_real_not_quantized_type, has_static_shape()}));
-    auto instance_norm_gamma_multiply_m = wrap_type<ov::op::v1::Multiply>({mvn_m, instance_norm_gamma_m});
-    auto instance_norm_opt_gamma_m =
-        std::make_shared<ov::pass::pattern::op::Or>(ov::OutputVector{mvn_m, instance_norm_gamma_multiply_m});
+    auto 
instance_norm_opt_gamma_m = optional({mvn_m, instance_norm_gamma_m}); auto instance_norm_beta_m = any_input(all_of({has_real_not_quantized_type, has_static_shape()})); - auto instance_norm_beta_add_m = wrap_type({instance_norm_opt_gamma_m, instance_norm_beta_m}); - auto instance_norm_opt_gamma_opt_beta_m = std::make_shared( - ov::OutputVector{instance_norm_opt_gamma_m, instance_norm_beta_add_m}); + auto instance_norm_opt_gamma_opt_beta_m = + optional({instance_norm_opt_gamma_m, instance_norm_beta_m}); auto post_instance_norm_shape_m = any_input(all_of({rank_equals(1), has_static_dim(0)})); auto post_instance_norm_reshape_m = From 745132be4f64957e4e9fc976365e7466cdde23ce Mon Sep 17 00:00:00 2001 From: Jedrzej Hajduczenia Date: Tue, 4 Feb 2025 16:33:08 +0100 Subject: [PATCH 28/45] Add GroupNormalizationFusion shared functional subgraph test --- .../group_normalization_fusion.hpp | 73 +++ .../subgraph/group_normalization_fusion.hpp | 490 ++++++++++++++++++ 2 files changed, 563 insertions(+) create mode 100644 src/tests/functional/plugin/shared/include/subgraph_tests/group_normalization_fusion.hpp create mode 100644 src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp diff --git a/src/tests/functional/plugin/shared/include/subgraph_tests/group_normalization_fusion.hpp b/src/tests/functional/plugin/shared/include/subgraph_tests/group_normalization_fusion.hpp new file mode 100644 index 00000000000000..28f353fa18fcdf --- /dev/null +++ b/src/tests/functional/plugin/shared/include/subgraph_tests/group_normalization_fusion.hpp @@ -0,0 +1,73 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "shared_test_classes/subgraph/group_normalization_fusion.hpp" + +namespace ov { +namespace test { + +TEST_P(GroupNormalizationFusionSubgraphTestsF_f32, GroupNormalizationFusionSubgraphTests_f32) { + GroupNormalizationFusionSubgraphTestsF_f32::run(); +} + +TEST_P(GroupNormalizationFusionSubgraphTestsF_f16, GroupNormalizationFusionSubgraphTests_f16) { + GroupNormalizationFusionSubgraphTestsF_f16::run(); +} + +TEST_P(GroupNormalizationFusionSubgraphTestsF_bf16, GroupNormalizationFusionSubgraphTests_bf16) { + GroupNormalizationFusionSubgraphTestsF_bf16::run(); +} + +TEST_P(GroupNormalizationFusionSubgraphTestsF_u8, GroupNormalizationFusionSubgraphTests_u8) { + GroupNormalizationFusionSubgraphTestsF_u8::run(); +} + +TEST_P(GroupNormalizationFusionSubgraphTestsF_u16, GroupNormalizationFusionSubgraphTests_u16) { + GroupNormalizationFusionSubgraphTestsF_u16::run(); +} + +TEST_P(GroupNormalizationFusionSubgraphTestsF_u32, GroupNormalizationFusionSubgraphTests_u32) { + GroupNormalizationFusionSubgraphTestsF_u32::run(); +} + +TEST_P(GroupNormalizationFusionSubgraphTestsF_u64, GroupNormalizationFusionSubgraphTests_u64) { + GroupNormalizationFusionSubgraphTestsF_u64::run(); +} + +TEST_P(GroupNormalizationFusionSubgraphTestsF_i8, GroupNormalizationFusionSubgraphTests_i8) { + GroupNormalizationFusionSubgraphTestsF_i8::run(); +} + +TEST_P(GroupNormalizationFusionSubgraphTestsF_i16, GroupNormalizationFusionSubgraphTests_i16) { + GroupNormalizationFusionSubgraphTestsF_i16::run(); +} + +TEST_P(GroupNormalizationFusionSubgraphTestsF_i32, GroupNormalizationFusionSubgraphTests_i32) { + GroupNormalizationFusionSubgraphTestsF_i32::run(); +} + +TEST_P(GroupNormalizationFusionSubgraphTestsF_i64, GroupNormalizationFusionSubgraphTests_i64) { + GroupNormalizationFusionSubgraphTestsF_i64::run(); +} + 
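+// Note: TEST_P needs a concrete, non-template fixture, which is why each
+// precision below gets its own thin class that just forwards to run().
+// Adding coverage for another element type would follow the same shape,
+// e.g. (hypothetical, not part of this patch):
+//
+//   TEST_P(GroupNormalizationFusionSubgraphTestsF_f64, GroupNormalizationFusionSubgraphTests_f64) {
+//       GroupNormalizationFusionSubgraphTestsF_f64::run();
+//   }
+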
+TEST_P(GroupNormalizationFusionSubgraphTestsF_f8e4m3, GroupNormalizationFusionSubgraphTests_f8e4m3) { + GroupNormalizationFusionSubgraphTestsF_f8e4m3::run(); +} + +TEST_P(GroupNormalizationFusionSubgraphTestsF_f8e5m2, GroupNormalizationFusionSubgraphTests_f8e5m2) { + GroupNormalizationFusionSubgraphTestsF_f8e5m2::run(); +} + +TEST_P(GroupNormalizationFusionSubgraphTestsF_f4e2m1, GroupNormalizationFusionSubgraphTests_f4e2m1) { + GroupNormalizationFusionSubgraphTestsF_f4e2m1::run(); +} + +TEST_P(GroupNormalizationFusionSubgraphTestsF_f8e8m0, GroupNormalizationFusionSubgraphTests_f8e8m0) { + GroupNormalizationFusionSubgraphTestsF_f8e8m0::run(); +} + +} // namespace test +} // namespace ov diff --git a/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp b/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp new file mode 100644 index 00000000000000..61e40355541a4b --- /dev/null +++ b/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp @@ -0,0 +1,490 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "common_test_utils/data_utils.hpp" +#include "common_test_utils/ov_test_utils.hpp" +#include "functional_test_utils/crash_handler.hpp" +#include "shared_test_classes/base/ov_subgraph.hpp" +#include "transformations/common_optimizations/group_normalization_fusion.hpp" + +using namespace testing; + +namespace ov { +namespace test { + +using GroupNormalizationFusionTestBaseValues = + std::tuple; // epsilon + +using GroupNormalizationFusionTransformationsTestValues = + std::tuple; // reference device properties + +template +std::vector> expand_vals(std::vector> old_vals, + std::tuple added_vals) { + std::vector> res; + for (const std::tuple& t : old_vals) { + auto new_tuple = std::tuple_cat(t, added_vals); + res.push_back(new_tuple); + } + return res; +} + +template +class GroupNormalizationFusionTestBase { +public: + static constexpr element::Type T_elem_t = T_elem; + typedef typename ov::element_type_traits::value_type T_store_t; + +protected: + size_t numChannels; + bool instanceNormGammaPresent; + bool instanceNormBetaPresent; + + std::vector instanceNormGammaVals; + std::vector instanceNormBetaVals; + std::vector groupNormGammaVals; + std::vector groupNormBetaVals; + + PartialShape dataShape; + Shape instanceNormGammaShape; + Shape instanceNormBetaShape; + Shape groupNormGammaShape; + Shape groupNormBetaShape; + size_t numGroups; + float epsilon; + + virtual void read_test_parameters() = 0; + + void generate_weights_init_values() { + if (instanceNormGammaPresent) + instanceNormGammaVals = test::utils::generateVector(shape_size(instanceNormGammaShape), 10, 1, 1); + if (instanceNormBetaPresent) + instanceNormBetaVals = test::utils::generateVector(shape_size(instanceNormBetaShape), 10, 1, 2); + groupNormGammaVals = test::utils::generateVector(shape_size(groupNormGammaShape), 10, 1, 3); + groupNormBetaVals = test::utils::generateVector(shape_size(groupNormBetaShape), 10, 1, 4); + } + + std::shared_ptr create_model() { + auto input = std::make_shared(T_elem_t, dataShape); + auto pre_mvn_shape_const = + op::v0::Constant::create(element::i64, Shape{3}, {0, static_cast(numGroups), -1}); + auto pre_mvn_reshape = std::make_shared(input, pre_mvn_shape_const, true); + + auto mvn_axes_const = op::v0::Constant::create(element::i64, Shape{1}, {2}); + auto mvn = + 
std::make_shared(pre_mvn_reshape, mvn_axes_const, true, epsilon, op::MVNEpsMode::INSIDE_SQRT); + + std::shared_ptr opt_instance_norm_gamma_multiply = mvn; + if (instanceNormGammaPresent) { + auto instance_norm_gamma_const = + op::v0::Constant::create(T_elem_t, instanceNormGammaShape, instanceNormGammaVals); + opt_instance_norm_gamma_multiply = std::make_shared(mvn, instance_norm_gamma_const); + } + + std::shared_ptr opt_instance_norm_beta_add = opt_instance_norm_gamma_multiply; + if (instanceNormBetaPresent) { + auto instance_norm_beta_const = + op::v0::Constant::create(T_elem_t, instanceNormBetaShape, instanceNormBetaVals); + opt_instance_norm_beta_add = + std::make_shared(opt_instance_norm_gamma_multiply, instance_norm_beta_const); + } + + auto post_instance_norm_shape = std::make_shared(input); + + auto post_instance_norm_reshape = + std::make_shared(opt_instance_norm_beta_add, post_instance_norm_shape, true); + + auto group_norm_gamma_const = op::v0::Constant::create(T_elem_t, groupNormGammaShape, groupNormGammaVals); + auto group_norm_gamma_multiply = + std::make_shared(post_instance_norm_reshape, group_norm_gamma_const); + + auto group_norm_beta_const = op::v0::Constant::create(T_elem_t, groupNormBetaShape, groupNormBetaVals); + auto group_norm_beta_add = std::make_shared(group_norm_gamma_multiply, group_norm_beta_const); + + return std::make_shared(NodeVector{group_norm_beta_add}, ParameterVector{input}); + } +}; + +template +class GroupNormalizationFusionSubgraphTestsF + : public GroupNormalizationFusionTestBase, + public ov::test::SubgraphBaseTest, + public testing::WithParamInterface { +public: + static std::string getTestCaseName( + const testing::TestParamInfo& obj) { + const auto& params = obj.param; + + const auto& data_shape = std::get<0>(params); + const auto& instance_norm_gamma_shape = std::get<1>(params); + const auto& instance_norm_beta_shape = std::get<2>(params); + const auto& group_norm_gamma_shape = std::get<3>(params); + const auto& group_norm_beta_shape = std::get<4>(params); + const auto& num_groups = std::get<5>(params); + const auto& epsilon = std::get<6>(params); + const auto& positive_test = std::get<7>(params); + const auto& device_name = std::get<8>(params); + const auto& device_properties = std::get<9>(params); + const auto& ref_device_name = std::get<10>(params); + const auto& ref_device_properties = std::get<11>(params); + + std::ostringstream results; + + results << "T=" << T_elem_t << "_"; + results << "Input=" << ov::test::utils::partialShape2str({data_shape}) << "_"; + results << "InstNormGamma=" << ov::test::utils::partialShape2str({instance_norm_gamma_shape}) << "_"; + results << "InstNormBeta=" << ov::test::utils::partialShape2str({instance_norm_beta_shape}) << "_"; + results << "GroupNormGamma=" << ov::test::utils::partialShape2str({group_norm_gamma_shape}) << "_"; + results << "GroupNormBeta=" << ov::test::utils::partialShape2str({group_norm_beta_shape}) << "_"; + results << "NumGroups=" << num_groups << "_"; + results << "Epsilon=" << epsilon << "_"; + results << "PositiveTest=" << std::boolalpha << positive_test << "_"; + results << "Device=" << device_name << "_"; + results << "DeviceCfg=("; + for (auto iter = device_properties.begin(); iter != device_properties.end(); iter++) { + results << iter->first << "=" << iter->second.as(); + if (std::next(iter) != device_properties.end()) + results << "_"; + } + results << ")_"; + results << "RefDevice=" << ref_device_name << "_"; + results << "RefDeviceCfg=("; + for (auto iter = 
ref_device_properties.begin(); iter != ref_device_properties.end(); iter++) {
+            results << iter->first << "=" << iter->second.as<std::string>();
+            if (std::next(iter) != ref_device_properties.end())
+                results << "_";
+        }
+        results << ")";
+        return results.str();
+    }
+
+protected:
+    bool positiveTest;
+    std::string targetDeviceName;
+    ov::AnyMap targetConfiguration;
+    std::string refDevice;
+    ov::AnyMap refConfiguration;
+
+    ElementType refInferencePrecision;
+    ov::CompiledModel compiledRefModel;
+    ov::InferRequest refInferRequest;
+
+    void TearDown() override {
+        SubgraphBaseTest::TearDown();
+    }
+
+    virtual void read_test_parameters() {
+        const auto& params = GetParam();
+
+        dataShape = std::get<0>(params);
+        if (!dataShape.rank().is_static())
+            throw std::runtime_error("Rank of input tensor has to be static!");
+        if (dataShape.rank().get_max_length() < 2)
+            throw std::runtime_error("Expected at least two dimensions in input tensor!");
+        if (!dataShape[1].is_static())
+            throw std::runtime_error("Channel dimension in input tensor has to be static!");
+
+        numChannels = static_cast<size_t>(dataShape[1].get_max_length());
+        instanceNormGammaShape = std::get<1>(params);
+        instanceNormBetaShape = std::get<2>(params);
+        groupNormGammaShape = std::get<3>(params);
+        groupNormBetaShape = std::get<4>(params);
+        numGroups = std::get<5>(params);
+        epsilon = std::get<6>(params);
+        positiveTest = std::get<7>(params);
+        targetDeviceName = std::get<8>(params);
+        targetConfiguration = std::get<9>(params);
+        refDevice = std::get<10>(params);
+        refConfiguration = std::get<11>(params);
+
+        instanceNormGammaPresent = (instanceNormGammaShape != Shape{});
+        instanceNormBetaPresent = (instanceNormBetaShape != Shape{});
+
+        inType = T_elem_t;
+        outType = T_elem_t;
+        targetDevice = targetDeviceName;
+        configuration = targetConfiguration;
+
+        if (positiveTest) {
+            if ((instanceNormGammaShape != Shape{}) && (shape_size(instanceNormGammaShape) != numGroups))
+                throw std::runtime_error("Shape of instance norm gamma has to either be an empty Shape or contain "
+                                         "exactly <num_groups> elements");
+            if ((instanceNormBetaShape != Shape{}) && (shape_size(instanceNormBetaShape) != numGroups))
+                throw std::runtime_error("Shape of instance norm beta has to either be an empty Shape or contain "
+                                         "exactly <num_groups> elements");
+            if (shape_size(groupNormGammaShape) != numChannels)
+                throw std::runtime_error("Shape of group norm gamma has to contain exactly <num_channels> elements");
+            if (shape_size(groupNormBetaShape) != numChannels)
+                throw std::runtime_error("Shape of group norm beta has to contain exactly <num_channels> elements");
+
+            instanceNormGammaPresent = instanceNormGammaPresent && (shape_size(instanceNormGammaShape) == numGroups);
+            instanceNormBetaPresent = instanceNormBetaPresent && (shape_size(instanceNormBetaShape) == numGroups);
+        }
+    }
+
+    void configure_device() {
+        if (targetConfiguration.count(ov::hint::inference_precision.name()) <= 0) {
+            targetConfiguration.insert({ov::hint::inference_precision.name(), T_elem_t});
+        }
+    }
+
+    void configure_ref_device() {
+        if (refConfiguration.count(ov::hint::inference_precision.name()) <= 0) {
+            refConfiguration.insert({ov::hint::inference_precision.name(), T_elem_t});
+        }
+    }
+
+    void configure_ref_model() {
+        // configure input precision
+        ov::preprocess::PrePostProcessor p(functionRefs);
+        {
+            auto& params = functionRefs->get_parameters();
+            for (size_t i = 0; i < params.size(); i++) {
+                if (inType != ov::element::Type_t::undefined) {
+                    p.input(i).tensor().set_element_type(inType);
+                }
+            }
+        }
+
+        // configure output precision
+        {
+            auto results = 
functionRefs->get_results(); + for (size_t i = 0; i < results.size(); i++) { + if (outType != ov::element::Type_t::undefined) { + p.output(i).tensor().set_element_type(outType); + } + } + } + functionRefs = p.build(); + } + + void compile_ref_model() { + if (is_report_stages) { + std::cout << "[ REFERENCE ] `GroupNormalizationFusionSubgraphTestsF::compile_ref_model()` is started" + << std::endl; + } + auto start_time = std::chrono::system_clock::now(); + + configure_ref_model(); + core_configuration(this); + compiledRefModel = core->compile_model(functionRefs, refDevice, refConfiguration); + if (is_report_stages) { + auto end_time = std::chrono::system_clock::now(); + std::chrono::duration duration = end_time - start_time; + std::cout << "[ REFERENCE ] `GroupNormalizationFusionSubgraphTestsF::compile_ref_model()` is finished " + "successfully. Duration is " + << duration.count() << "s" << std::endl; + } + try { + refInferencePrecision = core->get_property(refDevice, ov::hint::inference_precision); + } catch (std::exception& e) { + std::cout << "[ WARNING ] Impossible to get Inference Precision with exception: " << e.what() << std::endl; + } + } + + void init_thresholds() override { + if (!targetStaticShapes.empty()) { + size_t problem_size = shape_size(dataShape.get_shape()); + + abs_threshold = pow(problem_size, 0.5) * test::utils::get_eps_by_ov_type(outType); + rel_threshold = abs_threshold; + } + } + + void infer_ref(const std::map, ov::Tensor>& inputs_ref) { + refInferRequest = compiledRefModel.create_infer_request(); + for (const auto& input : inputs_ref) { + refInferRequest.set_tensor(input.first, input.second); + } + refInferRequest.infer(); + } + + std::vector calculate_refs() { + if (is_report_stages) { + std::cout << "[ REFERENCE ] `GroupNormalizationFusionSubgraphTestsF::calculate_refs()` is started" + << std::endl; + } + auto start_time = std::chrono::system_clock::now(); + + update_ref_model(); + match_parameters(function->get_parameters(), functionRefs->get_parameters()); + + std::map, ov::Tensor> inputs_ref; + for (const auto& param : functionRefs->get_parameters()) { + inputs_ref[param] = inputs.at(matched_parameters[param]); + } + + infer_ref(inputs_ref); + auto outputs = std::vector{}; + for (const auto& output : functionRefs->outputs()) { + outputs.push_back(refInferRequest.get_tensor(output)); + } + if (is_report_stages) { + auto end_time = std::chrono::system_clock::now(); + std::chrono::duration duration = end_time - start_time; + std::cout << "[ REFERENCE ] `GroupNormalizationFusionSubgraphTestsF::calculate_refs()` is finished " + "successfully. 
Duration is " + << duration.count() << "s" << std::endl; + } + return outputs; + } + + virtual void generate_inputs(const std::vector& targetInputStaticShapes) override { + inputs.clear(); + + auto itTargetShape = targetInputStaticShapes.begin(); + for (const auto& param : function->get_parameters()) { + std::shared_ptr inputNode = param; + for (size_t i = 0; i < param->get_output_size(); i++) { + for (const auto& node : param->get_output_target_inputs(i)) { + std::shared_ptr nodePtr = node.get_node()->shared_from_this(); + for (size_t port = 0; port < nodePtr->get_input_size(); ++port) { + if (nodePtr->get_input_node_ptr(port)->shared_from_this() == inputNode->shared_from_this()) { + const auto& tensor = ov::test::utils::create_and_fill_tensor(inType, *itTargetShape); + inputs.insert({param, tensor}); + break; + } + } + } + } + itTargetShape++; + } + } + +public: + void run() { + is_reported = true; + bool isCurrentTestDisabled = ov::test::utils::current_test_is_disabled(); + + ov::test::utils::PassRate::Statuses status = isCurrentTestDisabled + ? ov::test::utils::PassRate::Statuses::SKIPPED + : ov::test::utils::PassRate::Statuses::CRASHED; + + if (isCurrentTestDisabled) + GTEST_SKIP() << "Disabled test due to configuration" << std::endl; + + // in case of crash jump will be made and work will be continued + auto crashHandler = std::unique_ptr(new ov::test::utils::CrashHandler()); + + // place to jump in case of a crash + int jmpRes = 0; +#ifdef _WIN32 + jmpRes = setjmp(ov::test::utils::env); +#else + jmpRes = sigsetjmp(ov::test::utils::env, 1); +#endif + if (jmpRes == ov::test::utils::JMP_STATUS::ok) { + crashHandler->StartTimer(); + std::string errorMessage; + try { + read_test_parameters(); + generate_weights_init_values(); + functionRefs = create_model(); + function = functionRefs->clone(); + pass::Manager m; + m.register_pass(); + OV_ASSERT_NO_THROW(m.run_passes(function)); + + summary.setDeviceName(targetDevice); + summary.updateOPsStats(function, status, rel_influence_coef); + if (positiveTest) { + ASSERT_EQ(count_ops_of_type(functionRefs), 0); + ASSERT_EQ(count_ops_of_type(function), 1); + + if (!function->is_dynamic()) { + configure_device(); + configure_ref_device(); + auto input_shapes = static_partial_shapes_to_test_representation({dataShape}); + init_input_shapes(input_shapes); + ASSERT_FALSE(targetStaticShapes.empty() && !function->get_parameters().empty()) + << "Target Static Shape is empty!!!"; + compile_model(); + compile_ref_model(); + for (const auto& targetStaticShapeVec : targetStaticShapes) { + generate_inputs(targetStaticShapeVec); + validate(); + } + } + } else { + ASSERT_EQ(count_ops_of_type(functionRefs), 0); + ASSERT_EQ(count_ops_of_type(function), 0); + } + status = ov::test::utils::PassRate::Statuses::PASSED; + } catch (const std::exception& ex) { + if (callback_exception != nullptr) { + // exception will be checked by callback. + callback_exception(ex); + return; + } else { + status = ov::test::utils::PassRate::Statuses::FAILED; + errorMessage = ex.what(); + } + } catch (...) 
{
+                status = ov::test::utils::PassRate::Statuses::FAILED;
+                errorMessage = "Unknown failure occurred.";
+            }
+            summary.updateOPsStats(function, status, rel_influence_coef);
+            if (status != ov::test::utils::PassRate::Statuses::PASSED) {
+                GTEST_FATAL_FAILURE_(errorMessage.c_str());
+            }
+        } else if (jmpRes == ov::test::utils::JMP_STATUS::anyError) {
+            OPENVINO_THROW("Crash happens");
+        } else if (jmpRes == ov::test::utils::JMP_STATUS::alarmErr) {
+            summary.updateOPsStats(function, ov::test::utils::PassRate::Statuses::HANGED, rel_influence_coef);
+            OPENVINO_THROW("Crash happens");
+        }
+    }
+};
+
+class GroupNormalizationFusionSubgraphTestsF_f32
+    : public GroupNormalizationFusionSubgraphTestsF<ov::element::Type_t::f32> {};
+class GroupNormalizationFusionSubgraphTestsF_f16
+    : public GroupNormalizationFusionSubgraphTestsF<ov::element::Type_t::f16> {};
+class GroupNormalizationFusionSubgraphTestsF_bf16
+    : public GroupNormalizationFusionSubgraphTestsF<ov::element::Type_t::bf16> {};
+
+class GroupNormalizationFusionSubgraphTestsF_u8
+    : public GroupNormalizationFusionSubgraphTestsF<ov::element::Type_t::u8> {};
+class GroupNormalizationFusionSubgraphTestsF_u16
+    : public GroupNormalizationFusionSubgraphTestsF<ov::element::Type_t::u16> {};
+class GroupNormalizationFusionSubgraphTestsF_u32
+    : public GroupNormalizationFusionSubgraphTestsF<ov::element::Type_t::u32> {};
+class GroupNormalizationFusionSubgraphTestsF_u64
+    : public GroupNormalizationFusionSubgraphTestsF<ov::element::Type_t::u64> {};
+class GroupNormalizationFusionSubgraphTestsF_i8
+    : public GroupNormalizationFusionSubgraphTestsF<ov::element::Type_t::i8> {};
+class GroupNormalizationFusionSubgraphTestsF_i16
+    : public GroupNormalizationFusionSubgraphTestsF<ov::element::Type_t::i16> {};
+class GroupNormalizationFusionSubgraphTestsF_i32
+    : public GroupNormalizationFusionSubgraphTestsF<ov::element::Type_t::i32> {};
+class GroupNormalizationFusionSubgraphTestsF_i64
+    : public GroupNormalizationFusionSubgraphTestsF<ov::element::Type_t::i64> {};
+class GroupNormalizationFusionSubgraphTestsF_f8e4m3
+    : public GroupNormalizationFusionSubgraphTestsF<ov::element::Type_t::f8e4m3> {};
+class GroupNormalizationFusionSubgraphTestsF_f8e5m2
+    : public GroupNormalizationFusionSubgraphTestsF<ov::element::Type_t::f8e5m2> {};
+class GroupNormalizationFusionSubgraphTestsF_f4e2m1
+    : public GroupNormalizationFusionSubgraphTestsF<ov::element::Type_t::f4e2m1> {};
+class GroupNormalizationFusionSubgraphTestsF_f8e8m0
+    : public GroupNormalizationFusionSubgraphTestsF<ov::element::Type_t::f8e8m0> {};
+
+}  // namespace test
+}  // namespace ov
From 538045f1a8fb755b95629c31eb65c2aa72fe7741 Mon Sep 17 00:00:00 2001
From: Jedrzej Hajduczenia 
Date: Tue, 4 Feb 2025 16:33:35 +0100
Subject: [PATCH 29/45] Add instance of GroupNormalizationFusion shared
 functional subgraph test for GPU

---
 .../group_normalization_fusion.cpp            | 421 ++++++++++++++++++
 1 file changed, 421 insertions(+)
 create mode 100644 src/plugins/intel_gpu/tests/functional/shared_tests_instances/subgraph_tests/group_normalization_fusion.cpp

diff --git a/src/plugins/intel_gpu/tests/functional/shared_tests_instances/subgraph_tests/group_normalization_fusion.cpp b/src/plugins/intel_gpu/tests/functional/shared_tests_instances/subgraph_tests/group_normalization_fusion.cpp
new file mode 100644
index 00000000000000..2b7d69e2585b97
--- /dev/null
+++ b/src/plugins/intel_gpu/tests/functional/shared_tests_instances/subgraph_tests/group_normalization_fusion.cpp
@@ -0,0 +1,421 @@
+// Copyright (C) 2018-2025 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "subgraph_tests/group_normalization_fusion.hpp"
+
+#include "common_test_utils/test_constants.hpp"
+
+using namespace ov::test;
+
+namespace {
+
+using GroupNormalizationFusionTransformationTestAdditionalValues =
+    std::tuple<bool,         // positive test
+               std::string,  // target device name
+               ov::AnyMap,   // target device properties
+               std::string,  // reference device name
+               ov::AnyMap>;  // reference device properties
+
+std::vector<GroupNormalizationFusionTestBaseValues> valid_vals = {
+    std::make_tuple(ov::PartialShape{1, 320}, ov::Shape{}, 
ov::Shape{}, ov::Shape{320}, ov::Shape{320}, 1, 1e-5f), + std::make_tuple(ov::PartialShape{1, 320, 2, 2}, + ov::Shape{1, 1, 1}, + ov::Shape{1, 1, 1}, + ov::Shape{320, 1, 1}, + ov::Shape{1, 320, 1, 1}, + 1, + 1e-5f), + std::make_tuple(ov::PartialShape{5, 320, 2, 2, 2}, + ov::Shape{1, 320, 1}, + ov::Shape{1, 320, 1}, + ov::Shape{320, 1, 1, 1}, + ov::Shape{320, 1, 1, 1}, + 320, + 1e-5f), + std::make_tuple(ov::PartialShape{ov::Dimension::dynamic(), + 320, + ov::Dimension::dynamic(), + ov::Dimension::dynamic(), + ov::Dimension::dynamic()}, + ov::Shape{1, 320, 1}, + ov::Shape{1, 320, 1}, + ov::Shape{320, 1, 1, 1}, + ov::Shape{320, 1, 1, 1}, + 320, + 1e-5f), + std::make_tuple(ov::PartialShape{3, 320}, + ov::Shape{32, 1}, + ov::Shape{32, 1}, + ov::Shape{320}, + ov::Shape{320}, + 32, + 1e-5f), + std::make_tuple(ov::PartialShape{2, 9, 4, 5, 6}, + ov::Shape{3, 1}, + ov::Shape{3, 1}, + ov::Shape{1, 9, 1, 1, 1}, + ov::Shape{1, 9, 1, 1, 1}, + 3, + 1e-5f), + std::make_tuple(ov::PartialShape{1, 320, 2, 4}, + ov::Shape{1, 32, 1}, + ov::Shape{1, 32, 1}, + ov::Shape{320, 1, 1}, + ov::Shape{320, 1, 1}, + 32, + 1e-5f), + std::make_tuple(ov::PartialShape{8, 320, 4, 8}, + ov::Shape{}, + ov::Shape{}, + ov::Shape{320, 1, 1}, + ov::Shape{1, 320, 1, 1}, + 32, + 1e-5f), + std::make_tuple(ov::PartialShape{1, 512, 4, 8}, + ov::Shape{}, + ov::Shape{1, 128, 1}, + ov::Shape{1, 512, 1, 1}, + ov::Shape{512, 1, 1}, + 128, + 1e-6f), + std::make_tuple(ov::PartialShape{1, 192, 2, 2}, + ov::Shape{1, 64, 1}, + ov::Shape{}, + ov::Shape{1, 192, 1, 1}, + ov::Shape{1, 192, 1, 1}, + 64, + 1e-6f)}; + +std::vector invalid_vals = { + std::make_tuple(ov::PartialShape{1, 320}, ov::Shape{}, ov::Shape{}, ov::Shape{}, ov::Shape{}, 1, 1e-5f), + std::make_tuple(ov::PartialShape{1, 320, 2, 2}, + ov::Shape{1, 1, 1}, + ov::Shape{1, 1, 1}, + ov::Shape{1, 1, 1}, + ov::Shape{1, 1, 1, 1}, + 1, + 1e-5f), + std::make_tuple(ov::PartialShape{1, 320, 2, 2}, + ov::Shape{}, + ov::Shape{}, + ov::Shape{320, 1, 1}, + ov::Shape{}, + 1, + 1e-5f), + std::make_tuple(ov::PartialShape{1, 320, 2, 2}, + ov::Shape{}, + ov::Shape{}, + ov::Shape{}, + ov::Shape{1, 320, 1, 1}, + 1, + 1e-5f), + std::make_tuple(ov::PartialShape{1, 320, 2, 2}, + ov::Shape{1, 1, 1}, + ov::Shape{1, 32, 1}, + ov::Shape{320, 1, 1}, + ov::Shape{320, 1, 1}, + 32, + 1e-5f), + std::make_tuple(ov::PartialShape{1, 320, 2, 2}, + ov::Shape{1, 32, 1}, + ov::Shape{1, 1, 1}, + ov::Shape{320, 1, 1}, + ov::Shape{320, 1, 1}, + 32, + 1e-5f), + std::make_tuple(ov::PartialShape{ov::Dimension::dynamic(), 512, ov::Dimension::dynamic(), ov::Dimension::dynamic()}, + ov::Shape{}, + ov::Shape{}, + ov::Shape{1, 512, 1, 1}, + ov::Shape{1, 512, 1, 1}, + 100, + 1e-6f)}; + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionSubgraphPositiveTests_f32, + GroupNormalizationFusionSubgraphTestsF_f32, + ValuesIn(expand_vals(valid_vals, + GroupNormalizationFusionTransformationTestAdditionalValues( + true, + ov::test::utils::DEVICE_GPU, + {}, + ov::test::utils::DEVICE_TEMPLATE, + {{"DISABLE_TRANSFORMATIONS", true}}))), + GroupNormalizationFusionSubgraphTestsF_f32::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionSubgraphPositiveTests_f16, + GroupNormalizationFusionSubgraphTestsF_f16, + ValuesIn(expand_vals(valid_vals, + GroupNormalizationFusionTransformationTestAdditionalValues( + true, + ov::test::utils::DEVICE_GPU, + {}, + ov::test::utils::DEVICE_TEMPLATE, + {{"DISABLE_TRANSFORMATIONS", true}}))), + GroupNormalizationFusionSubgraphTestsF_f16::getTestCaseName); + 
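+// How expand_vals() combines the value sets above: it is provided by the shared
+// GroupNormalizationFusion test utilities, and these instantiations only rely on it
+// appending the same additional-values tuple to every base tuple (mirroring the
+// add_positive_test_flag_to_vals() helper this series replaces). A minimal sketch of
+// that assumed behavior, illustrative only and not part of this patch:
+//
+//   template <typename... T_base, typename... T_extra>
+//   std::vector<std::tuple<T_base..., T_extra...>> expand_vals(
+//       const std::vector<std::tuple<T_base...>>& base_vals,
+//       const std::tuple<T_extra...>& additional_vals) {
+//       std::vector<std::tuple<T_base..., T_extra...>> res;
+//       // concatenate the device/config values onto every base test case
+//       for (const auto& t : base_vals)
+//           res.push_back(std::tuple_cat(t, additional_vals));
+//       return res;
+//   }
+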
+INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionSubgraphPositiveTests_bf16, + GroupNormalizationFusionSubgraphTestsF_bf16, + ValuesIn(expand_vals(valid_vals, + GroupNormalizationFusionTransformationTestAdditionalValues( + true, + ov::test::utils::DEVICE_GPU, + {{ov::hint::inference_precision(ov::element::f16)}}, + ov::test::utils::DEVICE_TEMPLATE, + {{"DISABLE_TRANSFORMATIONS", true}}))), + GroupNormalizationFusionSubgraphTestsF_bf16::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionSubgraphNegativeTests_f32, + GroupNormalizationFusionSubgraphTestsF_f32, + ValuesIn(expand_vals(invalid_vals, + GroupNormalizationFusionTransformationTestAdditionalValues( + false, + ov::test::utils::DEVICE_GPU, + {}, + ov::test::utils::DEVICE_TEMPLATE, + {{"DISABLE_TRANSFORMATIONS", true}}))), + GroupNormalizationFusionSubgraphTestsF_f32::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionSubgraphNegativeTests_f16, + GroupNormalizationFusionSubgraphTestsF_f16, + ValuesIn(expand_vals(invalid_vals, + GroupNormalizationFusionTransformationTestAdditionalValues( + false, + ov::test::utils::DEVICE_GPU, + {}, + ov::test::utils::DEVICE_TEMPLATE, + {{"DISABLE_TRANSFORMATIONS", true}}))), + GroupNormalizationFusionSubgraphTestsF_f16::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionSubgraphNegativeTests_bf16, + GroupNormalizationFusionSubgraphTestsF_bf16, + ValuesIn(expand_vals(invalid_vals, + GroupNormalizationFusionTransformationTestAdditionalValues( + false, + ov::test::utils::DEVICE_GPU, + {{ov::hint::inference_precision(ov::element::f16)}}, + ov::test::utils::DEVICE_TEMPLATE, + {{"DISABLE_TRANSFORMATIONS", true}}))), + GroupNormalizationFusionSubgraphTestsF_bf16::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionSubgraphNegativeTestsValidVals_u8, + GroupNormalizationFusionSubgraphTestsF_u8, + ValuesIn(expand_vals(valid_vals, + GroupNormalizationFusionTransformationTestAdditionalValues( + false, + ov::test::utils::DEVICE_GPU, + {}, + ov::test::utils::DEVICE_TEMPLATE, + {{"DISABLE_TRANSFORMATIONS", true}}))), + GroupNormalizationFusionSubgraphTestsF_u8::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionSubgraphNegativeTestsValidVals_u16, + GroupNormalizationFusionSubgraphTestsF_u16, + ValuesIn(expand_vals(valid_vals, + GroupNormalizationFusionTransformationTestAdditionalValues( + false, + ov::test::utils::DEVICE_GPU, + {}, + ov::test::utils::DEVICE_TEMPLATE, + {{"DISABLE_TRANSFORMATIONS", true}}))), + GroupNormalizationFusionSubgraphTestsF_u16::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionSubgraphNegativeTestsValidVals_u32, + GroupNormalizationFusionSubgraphTestsF_u32, + ValuesIn(expand_vals(valid_vals, + GroupNormalizationFusionTransformationTestAdditionalValues( + false, + ov::test::utils::DEVICE_GPU, + {}, + ov::test::utils::DEVICE_TEMPLATE, + {{"DISABLE_TRANSFORMATIONS", true}}))), + GroupNormalizationFusionSubgraphTestsF_u32::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionSubgraphNegativeTestsValidVals_u64, + GroupNormalizationFusionSubgraphTestsF_u64, + ValuesIn(expand_vals(valid_vals, + GroupNormalizationFusionTransformationTestAdditionalValues( + false, + ov::test::utils::DEVICE_GPU, + {}, + ov::test::utils::DEVICE_TEMPLATE, + {{"DISABLE_TRANSFORMATIONS", true}}))), + GroupNormalizationFusionSubgraphTestsF_u64::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionSubgraphNegativeTestsValidVals_i8, + GroupNormalizationFusionSubgraphTestsF_i8, + 
ValuesIn(expand_vals(valid_vals, + GroupNormalizationFusionTransformationTestAdditionalValues( + false, + ov::test::utils::DEVICE_GPU, + {}, + ov::test::utils::DEVICE_TEMPLATE, + {{"DISABLE_TRANSFORMATIONS", true}}))), + GroupNormalizationFusionSubgraphTestsF_i8::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionSubgraphNegativeTestsValidVals_i16, + GroupNormalizationFusionSubgraphTestsF_i16, + ValuesIn(expand_vals(valid_vals, + GroupNormalizationFusionTransformationTestAdditionalValues( + false, + ov::test::utils::DEVICE_GPU, + {}, + ov::test::utils::DEVICE_TEMPLATE, + {{"DISABLE_TRANSFORMATIONS", true}}))), + GroupNormalizationFusionSubgraphTestsF_i16::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionSubgraphNegativeTestsValidVals_i32, + GroupNormalizationFusionSubgraphTestsF_i32, + ValuesIn(expand_vals(valid_vals, + GroupNormalizationFusionTransformationTestAdditionalValues( + false, + ov::test::utils::DEVICE_GPU, + {}, + ov::test::utils::DEVICE_TEMPLATE, + {{"DISABLE_TRANSFORMATIONS", true}}))), + GroupNormalizationFusionSubgraphTestsF_i32::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionSubgraphNegativeTestsValidVals_f8e5m2, + GroupNormalizationFusionSubgraphTestsF_f8e5m2, + ValuesIn(expand_vals(valid_vals, + GroupNormalizationFusionTransformationTestAdditionalValues( + false, + ov::test::utils::DEVICE_GPU, + {}, + ov::test::utils::DEVICE_TEMPLATE, + {{"DISABLE_TRANSFORMATIONS", true}}))), + GroupNormalizationFusionSubgraphTestsF_f8e5m2::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionSubgraphNegativeTestsValidVals_f4e2m1, + GroupNormalizationFusionSubgraphTestsF_f4e2m1, + ValuesIn(expand_vals(valid_vals, + GroupNormalizationFusionTransformationTestAdditionalValues( + false, + ov::test::utils::DEVICE_GPU, + {}, + ov::test::utils::DEVICE_TEMPLATE, + {{"DISABLE_TRANSFORMATIONS", true}}))), + GroupNormalizationFusionSubgraphTestsF_f4e2m1::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionSubgraphNegativeTestsValidVals_f8e8m0, + GroupNormalizationFusionSubgraphTestsF_f8e8m0, + ValuesIn(expand_vals(valid_vals, + GroupNormalizationFusionTransformationTestAdditionalValues( + false, + ov::test::utils::DEVICE_GPU, + {}, + ov::test::utils::DEVICE_TEMPLATE, + {{"DISABLE_TRANSFORMATIONS", true}}))), + GroupNormalizationFusionSubgraphTestsF_f8e8m0::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionSubgraphNegativeTestsInvalidVals_u8, + GroupNormalizationFusionSubgraphTestsF_u8, + ValuesIn(expand_vals(invalid_vals, + GroupNormalizationFusionTransformationTestAdditionalValues( + false, + ov::test::utils::DEVICE_GPU, + {}, + ov::test::utils::DEVICE_TEMPLATE, + {{"DISABLE_TRANSFORMATIONS", true}}))), + GroupNormalizationFusionSubgraphTestsF_u8::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionSubgraphNegativeTestsInvalidVals_u16, + GroupNormalizationFusionSubgraphTestsF_u16, + ValuesIn(expand_vals(invalid_vals, + GroupNormalizationFusionTransformationTestAdditionalValues( + false, + ov::test::utils::DEVICE_GPU, + {}, + ov::test::utils::DEVICE_TEMPLATE, + {{"DISABLE_TRANSFORMATIONS", true}}))), + GroupNormalizationFusionSubgraphTestsF_u16::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionSubgraphNegativeTestsInvalidVals_u32, + GroupNormalizationFusionSubgraphTestsF_u32, + ValuesIn(expand_vals(invalid_vals, + GroupNormalizationFusionTransformationTestAdditionalValues( + false, + ov::test::utils::DEVICE_GPU, + {}, + 
ov::test::utils::DEVICE_TEMPLATE,
+                                                  {{"DISABLE_TRANSFORMATIONS", true}}))),
+                         GroupNormalizationFusionSubgraphTestsF_u32::getTestCaseName);
+
+INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionSubgraphNegativeTestsInvalidVals_u64,
+                         GroupNormalizationFusionSubgraphTestsF_u64,
+                         ValuesIn(expand_vals(invalid_vals,
+                                              GroupNormalizationFusionTransformationTestAdditionalValues(
+                                                  false,
+                                                  ov::test::utils::DEVICE_GPU,
+                                                  {},
+                                                  ov::test::utils::DEVICE_TEMPLATE,
+                                                  {{"DISABLE_TRANSFORMATIONS", true}}))),
+                         GroupNormalizationFusionSubgraphTestsF_u64::getTestCaseName);
+
+INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionSubgraphNegativeTestsInvalidVals_i8,
+                         GroupNormalizationFusionSubgraphTestsF_i8,
+                         ValuesIn(expand_vals(invalid_vals,
+                                              GroupNormalizationFusionTransformationTestAdditionalValues(
+                                                  false,
+                                                  ov::test::utils::DEVICE_GPU,
+                                                  {},
+                                                  ov::test::utils::DEVICE_TEMPLATE,
+                                                  {{"DISABLE_TRANSFORMATIONS", true}}))),
+                         GroupNormalizationFusionSubgraphTestsF_i8::getTestCaseName);
+
+INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionSubgraphNegativeTestsInvalidVals_i16,
+                         GroupNormalizationFusionSubgraphTestsF_i16,
+                         ValuesIn(expand_vals(invalid_vals,
+                                              GroupNormalizationFusionTransformationTestAdditionalValues(
+                                                  false,
+                                                  ov::test::utils::DEVICE_GPU,
+                                                  {},
+                                                  ov::test::utils::DEVICE_TEMPLATE,
+                                                  {{"DISABLE_TRANSFORMATIONS", true}}))),
+                         GroupNormalizationFusionSubgraphTestsF_i16::getTestCaseName);
+
+INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionSubgraphNegativeTestsInvalidVals_i32,
+                         GroupNormalizationFusionSubgraphTestsF_i32,
+                         ValuesIn(expand_vals(invalid_vals,
+                                              GroupNormalizationFusionTransformationTestAdditionalValues(
+                                                  false,
+                                                  ov::test::utils::DEVICE_GPU,
+                                                  {},
+                                                  ov::test::utils::DEVICE_TEMPLATE,
+                                                  {{"DISABLE_TRANSFORMATIONS", true}}))),
+                         GroupNormalizationFusionSubgraphTestsF_i32::getTestCaseName);
+
+INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionSubgraphNegativeTestsInvalidVals_f8e5m2,
+                         GroupNormalizationFusionSubgraphTestsF_f8e5m2,
+                         ValuesIn(expand_vals(invalid_vals,
+                                              GroupNormalizationFusionTransformationTestAdditionalValues(
+                                                  false,
+                                                  ov::test::utils::DEVICE_GPU,
+                                                  {},
+                                                  ov::test::utils::DEVICE_TEMPLATE,
+                                                  {{"DISABLE_TRANSFORMATIONS", true}}))),
+                         GroupNormalizationFusionSubgraphTestsF_f8e5m2::getTestCaseName);
+
+INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionSubgraphNegativeTestsInvalidVals_f4e2m1,
+                         GroupNormalizationFusionSubgraphTestsF_f4e2m1,
+                         ValuesIn(expand_vals(invalid_vals,
+                                              GroupNormalizationFusionTransformationTestAdditionalValues(
+                                                  false,
+                                                  ov::test::utils::DEVICE_GPU,
+                                                  {},
+                                                  ov::test::utils::DEVICE_TEMPLATE,
+                                                  {{"DISABLE_TRANSFORMATIONS", true}}))),
+                         GroupNormalizationFusionSubgraphTestsF_f4e2m1::getTestCaseName);
+
+INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionSubgraphNegativeTestsInvalidVals_f8e8m0,
+                         GroupNormalizationFusionSubgraphTestsF_f8e8m0,
+                         ValuesIn(expand_vals(invalid_vals,
+                                              GroupNormalizationFusionTransformationTestAdditionalValues(
+                                                  false,
+                                                  ov::test::utils::DEVICE_GPU,
+                                                  {},
+                                                  ov::test::utils::DEVICE_TEMPLATE,
+                                                  {{"DISABLE_TRANSFORMATIONS", true}}))),
+                         GroupNormalizationFusionSubgraphTestsF_f8e8m0::getTestCaseName);
+
+}  // namespace
From a4ad12cb66b481919b515fcc322e9546620a2c44 Mon Sep 17 00:00:00 2001
From: Jedrzej Hajduczenia 
Date: Tue, 4 Feb 2025 16:40:27 +0100
Subject: [PATCH 30/45] Refactor GroupNormalizationFusion transformation test

---
 .../group_normalization_fusion_tests.cpp      | 919 +++++++++---------
 1 file changed, 441 insertions(+), 478 deletions(-)

diff --git 
a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp
index cc2474353f64d9..abcac8b84011fa 100644
--- a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp
+++ b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp
@@ -4,564 +4,527 @@
 
 #include <gtest/gtest.h>
 
-#include "common_test_utils/data_utils.hpp"
-#include "common_test_utils/ov_test_utils.hpp"
-#include "openvino/core/model.hpp"
-#include "openvino/op/add.hpp"
-#include "openvino/op/constant.hpp"
-#include "openvino/op/group_normalization.hpp"
-#include "openvino/op/multiply.hpp"
-#include "openvino/op/mvn.hpp"
-#include "openvino/op/reshape.hpp"
-#include "openvino/op/shape_of.hpp"
-#include "openvino/pass/manager.hpp"
-#include "openvino/pass/serialize.hpp"
+#include "shared_test_classes/subgraph/group_normalization_fusion.hpp"
 #include "transformations/common_optimizations/group_normalization_fusion.hpp"
-#include "transformations/init_node_info.hpp"
 
 using namespace testing;
 using namespace ov;
 
-using ValuesContainerWithPositiveTestFlag =
-    std::tuple<bool, PartialShape, Shape, Shape, Shape, Shape, size_t, float>;  // epsilon
-
-template <element::Type_t T_act_elem,
-          element::Type_t T_gn_gamma_elem,
-          element::Type_t T_gn_beta_elem,
-          element::Type_t T_in_gamma_elem,
-          element::Type_t T_in_beta_elem>
-class GroupNormalizationFusionTestsFixture : public TestWithParam<ValuesContainerWithPositiveTestFlag> {
-public:
-    static constexpr element::Type_t T_act_elem_t = T_act_elem;
-    static constexpr element::Type_t T_gn_gamma_elem_t = T_gn_gamma_elem;
-    static constexpr element::Type_t T_gn_beta_elem_t = T_gn_beta_elem;
-    static constexpr element::Type_t T_in_gamma_elem_t = T_in_gamma_elem;
-    static constexpr element::Type_t T_in_beta_elem_t = T_in_beta_elem;
-
-    typedef typename ov::element_type_traits<T_act_elem_t>::value_type T_act_store_t;
-    typedef typename ov::element_type_traits<T_gn_gamma_elem_t>::value_type T_gn_gamma_store_t;
-    typedef typename ov::element_type_traits<T_gn_beta_elem_t>::value_type T_gn_beta_store_t;
-    typedef typename ov::element_type_traits<T_in_gamma_elem_t>::value_type T_in_gamma_store_t;
-    typedef typename ov::element_type_traits<T_in_beta_elem_t>::value_type T_in_beta_store_t;
-
-    void TestBody() override {
-        auto params = GetParam();
-        auto positive_test = std::get<0>(params);
-        auto data_shape = std::get<1>(params);
-        ASSERT_TRUE(data_shape[1].is_static());
-        auto num_channels = static_cast<size_t>(data_shape[1].get_max_length());
-        auto instance_norm_gamma_shape = std::get<2>(params);
-        auto instance_norm_beta_shape = std::get<3>(params);
-        auto group_norm_gamma_shape = std::get<4>(params);
-        auto group_norm_beta_shape = std::get<5>(params);
-        auto num_groups = std::get<6>(params);
-        auto epsilon = std::get<7>(params);
-
-        if (positive_test) {
-            if ((instance_norm_gamma_shape != Shape{}) && (shape_size(instance_norm_gamma_shape) != num_groups))
-                FAIL() << "Unexpected shape of instance norm beta - expected either empty shape (which means that it "
-                          "will not be put in the graph) or shape with exactly num_groups elements that can be "
-                          "merged with the result of MVN.";
-
-            if ((instance_norm_beta_shape != Shape{}) && (shape_size(instance_norm_beta_shape) != num_groups))
-                FAIL() << "Unexpected shape of instance norm beta - expected either empty shape (which means that it "
-                          "will not be put in the graph) or shape with exactly num_groups elements that can be "
-                          "merged with the result of MVN.";
-
-            if (shape_size(group_norm_gamma_shape) != num_channels)
-                FAIL()
-                    << "Unexpected shape of group norm gamma - expected shape with exactly num_channels elements that "
-                       "can be merged with the result of instance norm.";
-
-            if (shape_size(group_norm_beta_shape) != num_channels)
-                FAIL()
-                    << "Unexpected shape of group norm beta - expected shape with exactly num_channels elements that "
-                       "can be merged with the result of instance norm.";
-        }
-        auto instance_norm_gamma_present = (instance_norm_gamma_shape != Shape{});
-        auto instance_norm_beta_present = (instance_norm_beta_shape != Shape{});
-        auto group_norm_beta_present = (group_norm_beta_shape != Shape{});
-        auto group_norm_gamma_present = (group_norm_gamma_shape != Shape{});
-
-        if (positive_test) {
-            instance_norm_gamma_present =
-                instance_norm_gamma_present && (shape_size(instance_norm_gamma_shape) == num_groups);
-            instance_norm_beta_present =
-                instance_norm_beta_present && (shape_size(instance_norm_beta_shape) == num_groups);
-            group_norm_beta_present = group_norm_beta_present && (shape_size(group_norm_beta_shape) == num_channels);
-            group_norm_gamma_present = group_norm_gamma_present && (shape_size(group_norm_gamma_shape) == num_channels);
+using GroupNormalizationFusionTransformationTestValues =
+    std::tuple<PartialShape,  // data shape
+               Shape,         // instance norm gamma shape
+               Shape,         // instance norm beta shape
+               Shape,         // group norm gamma shape
+               Shape,         // group norm beta shape
+               size_t,        // number of groups
+               float,         // epsilon
+               bool>;  // whether it's a positive test that should run reference model or a negative test
+
+template <ov::element::Type_t T_elem>
+class GroupNormalizationFusionTransformationTestsF
+    : public ov::test::GroupNormalizationFusionTestBase<T_elem>,
+      public testing::TestWithParam<GroupNormalizationFusionTransformationTestValues> {
+protected:
+    bool positiveTest;
+    ov::pass::Manager manager;
+    ov::pass::Manager manager_ref;
+    std::shared_ptr<ov::Model> model;
+    std::shared_ptr<ov::Model> model_ref;
+
+    virtual void read_test_parameters() {
+        const auto& params = GetParam();
+
+        dataShape = std::get<0>(params);
+        if (!dataShape.rank().is_static())
+            throw std::runtime_error("Rank of input tensor has to be static!");
+        if (dataShape.rank().get_max_length() < 2)
+            throw std::runtime_error("Expected at least two dimensions in input tensor!");
+        if (!dataShape[1].is_static())
+            throw std::runtime_error("Channel dimension in input tensor has to be static!");
+
+        numChannels = static_cast<size_t>(dataShape[1].get_max_length());
+        instanceNormGammaShape = std::get<1>(params);
+        instanceNormBetaShape = std::get<2>(params);
+        groupNormGammaShape = std::get<3>(params);
+        groupNormBetaShape = std::get<4>(params);
+        numGroups = std::get<5>(params);
+        epsilon = std::get<6>(params);
+        positiveTest = std::get<7>(params);
+
+        instanceNormGammaPresent = (instanceNormGammaShape != Shape{});
+        instanceNormBetaPresent = (instanceNormBetaShape != Shape{});
+
+        if (positiveTest) {
+            if ((instanceNormGammaShape != Shape{}) && (shape_size(instanceNormGammaShape) != numGroups))
+                throw std::runtime_error("Shape of instance norm gamma has to either be an empty shape or contain "
+                                         "exactly <num_groups> elements");
+            if ((instanceNormBetaShape != Shape{}) && (shape_size(instanceNormBetaShape) != numGroups))
+                throw std::runtime_error("Shape of instance norm beta has to either be an empty shape or contain "
+                                         "exactly <num_groups> elements");
+            if (shape_size(groupNormGammaShape) != numChannels)
+                throw std::runtime_error("Shape of group norm gamma has to contain exactly <num_channels> elements");
+            if (shape_size(groupNormBetaShape) != numChannels)
+                throw std::runtime_error("Shape of group norm beta has to contain exactly <num_channels> elements");
+
+            instanceNormGammaPresent = instanceNormGammaPresent && (shape_size(instanceNormGammaShape) == numGroups);
+            instanceNormBetaPresent = instanceNormBetaPresent && (shape_size(instanceNormBetaShape) == numGroups);
         }
+    }
 
-        auto instance_norm_gamma_vals = std::vector<T_in_gamma_store_t>();
-        if (instance_norm_gamma_present)
-            instance_norm_gamma_vals =
-                test::utils::generateVector<T_in_gamma_elem_t>(shape_size(instance_norm_gamma_shape));
-
-        auto instance_norm_beta_vals = std::vector<T_in_beta_store_t>();
-        if (instance_norm_beta_present)
-            instance_norm_beta_vals =
-                test::utils::generateVector<T_in_beta_elem_t>(shape_size(instance_norm_beta_shape));
-
-        auto group_norm_gamma_vals = std::vector<T_gn_gamma_store_t>();
-        if (group_norm_gamma_present)
-            group_norm_gamma_vals = test::utils::generateVector<T_gn_gamma_elem_t>(shape_size(group_norm_gamma_shape));
-
-        auto group_norm_beta_vals = std::vector<T_gn_beta_store_t>();
-        if (group_norm_beta_present)
-            group_norm_beta_vals = test::utils::generateVector<T_gn_beta_elem_t>(shape_size(group_norm_beta_shape));
-
-        std::shared_ptr<Model> model(nullptr), model_ref(nullptr);
-        {
-            auto input = std::make_shared<op::v0::Parameter>(T_act_elem_t, data_shape);
-            auto pre_mvn_shape_const = op::v0::Constant::create(element::i64,
-                                                                Shape{3},
-                                                                {0, static_cast<int64_t>(num_groups), -1});
-            auto pre_mvn_reshape = std::make_shared<op::v1::Reshape>(input, pre_mvn_shape_const, true);
-
-            auto mvn_axes_const = op::v0::Constant::create(element::i64, Shape{1}, {1});
-            auto mvn = std::make_shared<op::v6::MVN>(pre_mvn_reshape,
-                                                     mvn_axes_const,
-                                                     true,
-                                                     epsilon,
-                                                     op::MVNEpsMode::INSIDE_SQRT);
-
-            std::shared_ptr<Node> opt_instance_norm_gamma_multiply = mvn;
-            if (instance_norm_gamma_present) {
-                auto instance_norm_gamma_const =
-                    op::v0::Constant::create(T_in_gamma_elem_t, instance_norm_gamma_shape, instance_norm_gamma_vals);
-                opt_instance_norm_gamma_multiply = std::make_shared<op::v1::Multiply>(mvn, instance_norm_gamma_const);
-            }
-
-            std::shared_ptr<Node> opt_instance_norm_beta_add = opt_instance_norm_gamma_multiply;
-            if (instance_norm_beta_present) {
-                auto instance_norm_beta_const =
-                    op::v0::Constant::create(T_in_beta_elem_t, instance_norm_beta_shape, instance_norm_beta_vals);
-                opt_instance_norm_beta_add =
-                    std::make_shared<op::v1::Add>(opt_instance_norm_gamma_multiply, instance_norm_beta_const);
-            }
-
-            auto post_instance_norm_shape = std::make_shared<op::v3::ShapeOf>(input);
-
-            auto post_instance_norm_reshape =
-                std::make_shared<op::v1::Reshape>(opt_instance_norm_beta_add, post_instance_norm_shape, true);
-
-            std::shared_ptr<Node> opt_group_norm_gamma_multiply = post_instance_norm_reshape;
-            if (group_norm_gamma_present) {
-                auto group_norm_gamma_const =
-                    op::v0::Constant::create(T_gn_gamma_elem_t, group_norm_gamma_shape, group_norm_gamma_vals);
-                opt_group_norm_gamma_multiply =
-                    std::make_shared<op::v1::Multiply>(post_instance_norm_reshape, group_norm_gamma_const);
-            }
-
-            std::shared_ptr<Node> opt_group_norm_beta_add = opt_group_norm_gamma_multiply;
-            if (group_norm_beta_present) {
-                auto group_norm_beta_const =
-                    op::v0::Constant::create(T_gn_beta_elem_t, group_norm_beta_shape, group_norm_beta_vals);
-                opt_group_norm_beta_add =
-                    std::make_shared<op::v1::Add>(opt_group_norm_gamma_multiply, group_norm_beta_const);
-            }
-
-            model = std::make_shared<Model>(NodeVector{opt_group_norm_beta_add}, ParameterVector{input});
-
-            pass::Manager m;
-            m.register_pass<pass::GroupNormalizationFusion>();
-            OV_ASSERT_NO_THROW(m.run_passes(model));
-        }
+    std::shared_ptr<ov::Model> create_ref_model() {
+        auto input = std::make_shared<op::v0::Parameter>(T_elem_t, dataShape);
+
+        auto group_norm_beta_corr_vals = groupNormBetaVals;
+        if (instanceNormBetaPresent)
+            for (auto i = 0; i < group_norm_beta_corr_vals.size(); i++)
+                group_norm_beta_corr_vals[i] =
+                    groupNormGammaVals[i] * instanceNormBetaVals[i / (numChannels / numGroups)] + groupNormBetaVals[i];
+        auto group_norm_beta_1d = op::v0::Constant::create(T_elem_t, Shape{numChannels}, group_norm_beta_corr_vals);
+
+        auto group_norm_gamma_corr_vals = groupNormGammaVals;
+        if (instanceNormGammaPresent)
+            for (auto i = 0; i < group_norm_gamma_corr_vals.size(); i++)
+                group_norm_gamma_corr_vals[i] =
+                    groupNormGammaVals[i] * instanceNormGammaVals[i / (numChannels / numGroups)];
+        auto group_norm_gamma_1d = op::v0::Constant::create(T_elem_t, Shape{numChannels}, group_norm_gamma_corr_vals);
+
+        auto group_norm = std::make_shared<op::v12::GroupNormalization>(input,
+                                                                        group_norm_gamma_1d,
+                                                                        group_norm_beta_1d,
+                                                                        numGroups,
+                                                                        epsilon);
+
+        return std::make_shared<ov::Model>(NodeVector{group_norm}, ParameterVector{input});
+    }
 
-        if (positive_test) {
-            auto input = std::make_shared<op::v0::Parameter>(T_act_elem_t, data_shape);
-
-            std::shared_ptr<Node> group_norm_beta_1d = nullptr;
-            std::shared_ptr<Node> group_norm_gamma_1d = nullptr;
-
-            if (instance_norm_gamma_present) {
-                if (!group_norm_gamma_present)
-                    group_norm_gamma_vals = std::vector<T_gn_gamma_store_t>(num_channels, 1);
-                auto group_norm_gamma_corr_vals = group_norm_gamma_vals;
-                for (auto i = 0; i < group_norm_gamma_corr_vals.size(); i++)
-                    group_norm_gamma_corr_vals[i] /= instance_norm_gamma_vals[i % num_groups];
-                group_norm_gamma_1d =
-                    op::v0::Constant::create(T_gn_gamma_elem_t, Shape{num_channels}, group_norm_gamma_corr_vals);
-                if (instance_norm_beta_present) {
-                    if (!group_norm_beta_present)
-                        group_norm_beta_vals = std::vector<T_gn_beta_store_t>(num_channels, 0);
-                    auto group_norm_beta_corr_vals = group_norm_beta_vals;
-                    for (auto i = 0; i < group_norm_beta_corr_vals.size(); i++)
-                        group_norm_beta_corr_vals[i] -=
-                            (group_norm_gamma_corr_vals[i] * instance_norm_beta_vals[i % num_groups]) /
-                            instance_norm_gamma_vals[i % num_groups];
-                    group_norm_beta_1d =
-                        op::v0::Constant::create(T_gn_beta_elem_t, Shape{num_channels}, group_norm_beta_corr_vals);
-                }
-            } else {
-                if (instance_norm_beta_present) {
-                    if (!group_norm_beta_present)
-                        group_norm_beta_vals = std::vector<T_gn_beta_store_t>(num_channels, 0);
-                    auto group_norm_beta_corr_vals = group_norm_beta_vals;
-                    for (auto i = 0; i < group_norm_beta_corr_vals.size(); i++)
-                        group_norm_beta_corr_vals[i] -=
-                            group_norm_gamma_vals[i] * instance_norm_beta_vals[i % num_groups];
-                    group_norm_beta_1d =
-                        op::v0::Constant::create(T_gn_beta_elem_t, Shape{num_channels}, group_norm_beta_corr_vals);
-                }
-            }
-
-            if (group_norm_gamma_present) {
-                if (group_norm_gamma_1d == nullptr) {
-                    group_norm_gamma_1d =
-                        op::v0::Constant::create(T_gn_gamma_elem_t, Shape{num_channels}, group_norm_gamma_vals);
-                }
-            } else {
-                group_norm_gamma_1d = std::make_shared<op::v0::Constant>(T_gn_gamma_elem_t, Shape{num_channels}, 1);
-            }
-
-            if (group_norm_beta_present) {
-                if (group_norm_beta_1d == nullptr) {
-                    group_norm_beta_1d =
-                        op::v0::Constant::create(T_gn_beta_elem_t, Shape{num_channels}, group_norm_beta_vals);
-                }
-            } else {
-                group_norm_beta_1d = std::make_shared<op::v0::Constant>(T_gn_beta_elem_t, Shape{num_channels}, 0);
-            }
-
-            auto group_norm = std::make_shared<op::v12::GroupNormalization>(input,
-                                                                            group_norm_gamma_1d,
-                                                                            group_norm_beta_1d,
-                                                                            num_groups,
-                                                                            epsilon);
-
-            model_ref = std::make_shared<Model>(NodeVector{group_norm}, ParameterVector{input});
-        }
+public:
+    static std::string getTestCaseName(
+        const testing::TestParamInfo<GroupNormalizationFusionTransformationTestValues>& obj) {
+        const auto& params = obj.param;
+
+        const auto& data_shape = std::get<0>(params);
+        const auto& instance_norm_gamma_shape = std::get<1>(params);
+        const auto& instance_norm_beta_shape = std::get<2>(params);
+        const auto& group_norm_gamma_shape = std::get<3>(params);
+        const auto& group_norm_beta_shape = std::get<4>(params);
+        const auto& num_groups = std::get<5>(params);
+        const auto& epsilon = std::get<6>(params);
+        const auto& positive_test = std::get<7>(params);
+
+        std::ostringstream results;
+
+        results << "T=" << T_elem_t << "_";
+        results << "Input=" << ov::test::utils::partialShape2str({data_shape}) << "_";
+        results << "InstNormGamma=" << ov::test::utils::partialShape2str({instance_norm_gamma_shape}) << "_";
+        results << "InstNormBeta=" << ov::test::utils::partialShape2str({instance_norm_beta_shape}) << "_";
+        results << "GroupNormGamma=" << 
ov::test::utils::partialShape2str({group_norm_gamma_shape}) << "_";
+        results << "GroupNormBeta=" << ov::test::utils::partialShape2str({group_norm_beta_shape}) << "_";
+        results << "NumGroups=" << num_groups << "_";
+        results << "Epsilon=" << epsilon << "_";
+        results << "PositiveTest=" << std::boolalpha << positive_test << "_";
+
+        return results.str();
+    }
 
-        if (positive_test) {
-            ASSERT_EQ(count_ops_of_type<op::v12::GroupNormalization>(model), 1);
-            auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::ACCURACY);
-            auto res = fc.compare(model, model_ref);
-            ASSERT_TRUE(res.valid) << res.message;
+    void run() {
+        read_test_parameters();
+        generate_weights_init_values();
+        model = create_model();
+
+        manager = ov::pass::Manager();
+        manager.register_pass<ov::pass::InitNodeInfo>();
+        manager.register_pass<ov::pass::GroupNormalizationFusion>();
+        OV_ASSERT_NO_THROW(manager.run_passes(model));
+
+        if (positiveTest) {
+            model_ref = create_ref_model();
+
+            manager_ref = ov::pass::Manager();
+            manager_ref.register_pass<ov::pass::InitNodeInfo>();
+            OV_ASSERT_NO_THROW(manager_ref.run_passes(model_ref));
+
+            const auto& f_parameters = model->get_parameters();
+            const auto& f_ref_parameters = model_ref->get_parameters();
+            ASSERT_EQ(f_parameters.size(), f_ref_parameters.size());
+            ASSERT_EQ(f_parameters.size(), 1);
+            ASSERT_EQ(f_parameters[0]->outputs().size(), f_ref_parameters[0]->outputs().size());
+            ASSERT_EQ(f_parameters[0]->outputs().size(), 1);
+            ASSERT_EQ(f_parameters[0]->get_element_type(), f_ref_parameters[0]->get_element_type());
+            ASSERT_EQ(f_parameters[0]->get_element_type(), T_elem);
+
+            const auto& f_results = model->get_results();
+            const auto& f_ref_results = model_ref->get_results();
+            ASSERT_EQ(f_results.size(), f_ref_results.size());
+            ASSERT_EQ(f_results.size(), 1);
+            ASSERT_EQ(f_results[0]->outputs().size(), f_ref_results[0]->outputs().size());
+            ASSERT_EQ(f_results[0]->outputs().size(), 1);
+            ASSERT_EQ(f_results[0]->inputs().size(), f_ref_results[0]->inputs().size());
+            ASSERT_EQ(f_results[0]->inputs().size(), 1);
+            ASSERT_EQ(f_results[0]->get_element_type(), f_ref_results[0]->get_element_type());
+            ASSERT_EQ(f_results[0]->get_element_type(), T_elem);
+            ASSERT_EQ(f_results[0]->get_output_partial_shape(0), f_ref_results[0]->get_output_partial_shape(0));
+            ASSERT_EQ(f_results[0]->get_output_partial_shape(0), f_parameters[0]->get_output_partial_shape(0));
+            ASSERT_EQ(f_ref_results[0]->get_output_partial_shape(0), f_ref_parameters[0]->get_output_partial_shape(0));
+            ASSERT_EQ(f_ref_results[0]->get_output_partial_shape(0), dataShape);
+
+            const auto& gn_node = f_results[0]->get_input_node_shared_ptr(0);
+            const auto& gn_ref_node = f_ref_results[0]->get_input_node_shared_ptr(0);
+            ASSERT_TRUE(ov::is_type<ov::op::v12::GroupNormalization>(gn_node));
+            ASSERT_TRUE(ov::is_type<ov::op::v12::GroupNormalization>(gn_ref_node));
+            ASSERT_EQ(gn_node->inputs().size(), gn_ref_node->inputs().size());
+            ASSERT_EQ(gn_node->inputs().size(), 3);
+            ASSERT_EQ(gn_node->get_input_partial_shape(0), gn_ref_node->get_input_partial_shape(0));
+            ASSERT_EQ(gn_node->get_input_partial_shape(0), dataShape);
+            ASSERT_EQ(shape_size(gn_node->get_input_shape(1)), shape_size(gn_ref_node->get_input_shape(1)));
+            ASSERT_EQ(shape_size(gn_node->get_input_shape(1)), numChannels);
+            ASSERT_EQ(shape_size(gn_node->get_input_shape(2)), shape_size(gn_ref_node->get_input_shape(2)));
+            ASSERT_EQ(shape_size(gn_node->get_input_shape(2)), numChannels);
+
+            const auto& gn_node_casted = ov::as_type_ptr<ov::op::v12::GroupNormalization>(gn_node);
+            const auto& gn_ref_node_casted = ov::as_type_ptr<ov::op::v12::GroupNormalization>(gn_ref_node);
+            ASSERT_EQ(gn_node_casted->get_epsilon(), gn_ref_node_casted->get_epsilon());
+            
ASSERT_EQ(gn_node_casted->get_epsilon(), epsilon);
+            ASSERT_EQ(gn_node_casted->get_num_groups(), gn_ref_node_casted->get_num_groups());
+            ASSERT_EQ(gn_node_casted->get_num_groups(), numGroups);
         } else {
             ASSERT_EQ(count_ops_of_type<op::v12::GroupNormalization>(model), 0);
         }
     }
 };
 
-class GroupNormalizationFusionTestsFixture_f16
-    : public GroupNormalizationFusionTestsFixture<element::Type_t::f16, element::Type_t::f16, element::Type_t::f16, element::Type_t::f16, element::Type_t::f16> {};
-class GroupNormalizationFusionTestsFixture_bf16
-    : public GroupNormalizationFusionTestsFixture<element::Type_t::bf16, element::Type_t::bf16, element::Type_t::bf16, element::Type_t::bf16, element::Type_t::bf16> {};
-class GroupNormalizationFusionTestsFixture_f32
-    : public GroupNormalizationFusionTestsFixture<element::Type_t::f32, element::Type_t::f32, element::Type_t::f32, element::Type_t::f32, element::Type_t::f32> {};
-class GroupNormalizationFusionTestsFixture_u8
-    : public GroupNormalizationFusionTestsFixture<element::Type_t::u8, element::Type_t::u8, element::Type_t::u8, element::Type_t::u8, element::Type_t::u8> {};
-class GroupNormalizationFusionTestsFixture_u16
-    : public GroupNormalizationFusionTestsFixture<element::Type_t::u16, element::Type_t::u16, element::Type_t::u16, element::Type_t::u16, element::Type_t::u16> {};
-class GroupNormalizationFusionTestsFixture_u32
-    : public GroupNormalizationFusionTestsFixture<element::Type_t::u32, element::Type_t::u32, element::Type_t::u32, element::Type_t::u32, element::Type_t::u32> {};
-class GroupNormalizationFusionTestsFixture_u64
-    : public GroupNormalizationFusionTestsFixture<element::Type_t::u64, element::Type_t::u64, element::Type_t::u64, element::Type_t::u64, element::Type_t::u64> {};
-class GroupNormalizationFusionTestsFixture_i8
-    : public GroupNormalizationFusionTestsFixture<element::Type_t::i8, element::Type_t::i8, element::Type_t::i8, element::Type_t::i8, element::Type_t::i8> {};
-class GroupNormalizationFusionTestsFixture_i16
-    : public GroupNormalizationFusionTestsFixture<element::Type_t::i16, element::Type_t::i16, element::Type_t::i16, element::Type_t::i16, element::Type_t::i16> {};
-class GroupNormalizationFusionTestsFixture_i32
-    : public GroupNormalizationFusionTestsFixture<element::Type_t::i32, element::Type_t::i32, element::Type_t::i32, element::Type_t::i32, element::Type_t::i32> {};
-class GroupNormalizationFusionTestsFixture_i64
-    : public GroupNormalizationFusionTestsFixture<element::Type_t::i64, element::Type_t::i64, element::Type_t::i64, element::Type_t::i64, element::Type_t::i64> {};
-class GroupNormalizationFusionTestsFixture_f8e4m3
-    : public GroupNormalizationFusionTestsFixture<element::Type_t::f8e4m3, element::Type_t::f8e4m3, element::Type_t::f8e4m3, element::Type_t::f8e4m3, element::Type_t::f8e4m3> {};
-class GroupNormalizationFusionTestsFixture_f8e5m2
-    : public GroupNormalizationFusionTestsFixture<element::Type_t::f8e5m2, element::Type_t::f8e5m2, element::Type_t::f8e5m2, element::Type_t::f8e5m2, element::Type_t::f8e5m2> {};
-class GroupNormalizationFusionTestsFixture_f4e2m1
-    : public GroupNormalizationFusionTestsFixture<element::Type_t::f4e2m1, element::Type_t::f4e2m1, element::Type_t::f4e2m1, element::Type_t::f4e2m1, element::Type_t::f4e2m1> {};
-class GroupNormalizationFusionTestsFixture_f8e8m0
-    : public GroupNormalizationFusionTestsFixture<element::Type_t::f8e8m0, element::Type_t::f8e8m0, element::Type_t::f8e8m0, element::Type_t::f8e8m0, element::Type_t::f8e8m0> {};
-
-TEST_P(GroupNormalizationFusionTestsFixture_f16, GroupNormalizationFusionTests_f16) {
-    GroupNormalizationFusionTestsFixture_f16::TestBody();
+class GroupNormalizationFusionTransformationTestsF_f32
+    : public GroupNormalizationFusionTransformationTestsF<ov::element::Type_t::f32> {};
+class GroupNormalizationFusionTransformationTestsF_f16
+    : public GroupNormalizationFusionTransformationTestsF<ov::element::Type_t::f16> {};
+class GroupNormalizationFusionTransformationTestsF_bf16
+    : public GroupNormalizationFusionTransformationTestsF<ov::element::Type_t::bf16> {};
+class GroupNormalizationFusionTransformationTestsF_u8
+    : public GroupNormalizationFusionTransformationTestsF<ov::element::Type_t::u8> {};
+class GroupNormalizationFusionTransformationTestsF_u16
+    : public GroupNormalizationFusionTransformationTestsF<ov::element::Type_t::u16> {};
+class GroupNormalizationFusionTransformationTestsF_u32
+    : public GroupNormalizationFusionTransformationTestsF<ov::element::Type_t::u32> {};
+class GroupNormalizationFusionTransformationTestsF_u64
+    : public GroupNormalizationFusionTransformationTestsF<ov::element::Type_t::u64> {};
+class GroupNormalizationFusionTransformationTestsF_i8
+    : public GroupNormalizationFusionTransformationTestsF<ov::element::Type_t::i8> {};
+class GroupNormalizationFusionTransformationTestsF_i16
+    : public GroupNormalizationFusionTransformationTestsF<ov::element::Type_t::i16> {};
+class GroupNormalizationFusionTransformationTestsF_i32
+    : public GroupNormalizationFusionTransformationTestsF<ov::element::Type_t::i32> {};
+class GroupNormalizationFusionTransformationTestsF_i64
+    : public GroupNormalizationFusionTransformationTestsF<ov::element::Type_t::i64> {};
+class GroupNormalizationFusionTransformationTestsF_f8e4m3
+    : public GroupNormalizationFusionTransformationTestsF<ov::element::Type_t::f8e4m3> {};
+class GroupNormalizationFusionTransformationTestsF_f8e5m2
+    : public GroupNormalizationFusionTransformationTestsF<ov::element::Type_t::f8e5m2> {};
+class GroupNormalizationFusionTransformationTestsF_f4e2m1
+    : public GroupNormalizationFusionTransformationTestsF<ov::element::Type_t::f4e2m1> {};
+class GroupNormalizationFusionTransformationTestsF_f8e8m0
+    : public GroupNormalizationFusionTransformationTestsF<ov::element::Type_t::f8e8m0> {};
+
+TEST_P(GroupNormalizationFusionTransformationTestsF_f32, GroupNormalizationFusionTransformationTests_f32) {
+    GroupNormalizationFusionTransformationTestsF_f32::run();
 }
 
-TEST_P(GroupNormalizationFusionTestsFixture_bf16, GroupNormalizationFusionTests_bf16) {
-    GroupNormalizationFusionTestsFixture_bf16::TestBody();
+TEST_P(GroupNormalizationFusionTransformationTestsF_f16, GroupNormalizationFusionTransformationTests_f16) {
+    GroupNormalizationFusionTransformationTestsF_f16::run();
 }
 
-TEST_P(GroupNormalizationFusionTestsFixture_f32, GroupNormalizationFusionTests_f32) {
-    GroupNormalizationFusionTestsFixture_f32::TestBody();
+TEST_P(GroupNormalizationFusionTransformationTestsF_bf16, GroupNormalizationFusionTransformationTests_bf16) {
+    GroupNormalizationFusionTransformationTestsF_bf16::run();
 }
 
-TEST_P(GroupNormalizationFusionTestsFixture_u8, GroupNormalizationFusionTests_u8) {
-    GroupNormalizationFusionTestsFixture_u8::TestBody();
+TEST_P(GroupNormalizationFusionTransformationTestsF_u8, GroupNormalizationFusionTransformationTests_u8) {
+    GroupNormalizationFusionTransformationTestsF_u8::run();
 }
 
-TEST_P(GroupNormalizationFusionTestsFixture_u16, GroupNormalizationFusionTests_u16) {
-    GroupNormalizationFusionTestsFixture_u16::TestBody();
+TEST_P(GroupNormalizationFusionTransformationTestsF_u16, GroupNormalizationFusionTransformationTests_u16) {
+    GroupNormalizationFusionTransformationTestsF_u16::run();
 }
 
-TEST_P(GroupNormalizationFusionTestsFixture_u32, GroupNormalizationFusionTests_u32) {
-    GroupNormalizationFusionTestsFixture_u32::TestBody();
+TEST_P(GroupNormalizationFusionTransformationTestsF_u32, GroupNormalizationFusionTransformationTests_u32) {
+    GroupNormalizationFusionTransformationTestsF_u32::run();
}
 
-TEST_P(GroupNormalizationFusionTestsFixture_u64, GroupNormalizationFusionTests_u64) {
-    GroupNormalizationFusionTestsFixture_u64::TestBody();
+TEST_P(GroupNormalizationFusionTransformationTestsF_u64, GroupNormalizationFusionTransformationTests_u64) {
+    GroupNormalizationFusionTransformationTestsF_u64::run();
 }
 
-TEST_P(GroupNormalizationFusionTestsFixture_i8, GroupNormalizationFusionTests_i8) {
-    GroupNormalizationFusionTestsFixture_i8::TestBody();
+TEST_P(GroupNormalizationFusionTransformationTestsF_i8, GroupNormalizationFusionTransformationTests_i8) {
+    GroupNormalizationFusionTransformationTestsF_i8::run();
 }
 
-TEST_P(GroupNormalizationFusionTestsFixture_i16, GroupNormalizationFusionTests_i16) {
-    GroupNormalizationFusionTestsFixture_i16::TestBody();
+TEST_P(GroupNormalizationFusionTransformationTestsF_i16, GroupNormalizationFusionTransformationTests_i16) {
+    GroupNormalizationFusionTransformationTestsF_i16::run();
 }
 
-TEST_P(GroupNormalizationFusionTestsFixture_i32, GroupNormalizationFusionTests_i32) {
-    GroupNormalizationFusionTestsFixture_i32::TestBody();
+TEST_P(GroupNormalizationFusionTransformationTestsF_i32, GroupNormalizationFusionTransformationTests_i32) {
+    GroupNormalizationFusionTransformationTestsF_i32::run();
 }
 
-TEST_P(GroupNormalizationFusionTestsFixture_i64, GroupNormalizationFusionTests_i64) {
-    GroupNormalizationFusionTestsFixture_i64::TestBody();
+TEST_P(GroupNormalizationFusionTransformationTestsF_i64, GroupNormalizationFusionTransformationTests_i64) {
+    GroupNormalizationFusionTransformationTestsF_i64::run();
 }
 
-TEST_P(GroupNormalizationFusionTestsFixture_f8e4m3, GroupNormalizationFusionTests_f8e4m3) {
-    GroupNormalizationFusionTestsFixture_f8e4m3::TestBody();
+TEST_P(GroupNormalizationFusionTransformationTestsF_f8e4m3, GroupNormalizationFusionTransformationTests_f8e4m3) {
+    GroupNormalizationFusionTransformationTestsF_f8e4m3::run();
 }
 
-TEST_P(GroupNormalizationFusionTestsFixture_f8e5m2, GroupNormalizationFusionTests_f8e5m2) {
-    GroupNormalizationFusionTestsFixture_f8e5m2::TestBody();
+TEST_P(GroupNormalizationFusionTransformationTestsF_f8e5m2, GroupNormalizationFusionTransformationTests_f8e5m2) {
+    GroupNormalizationFusionTransformationTestsF_f8e5m2::run();
 }
 
-TEST_P(GroupNormalizationFusionTestsFixture_f4e2m1, GroupNormalizationFusionTests_f4e2m1) {
-    GroupNormalizationFusionTestsFixture_f4e2m1::TestBody();
+TEST_P(GroupNormalizationFusionTransformationTestsF_f4e2m1, GroupNormalizationFusionTransformationTests_f4e2m1) {
+    GroupNormalizationFusionTransformationTestsF_f4e2m1::run();
 }
 
-TEST_P(GroupNormalizationFusionTestsFixture_f8e8m0, GroupNormalizationFusionTests_f8e8m0) {
-    GroupNormalizationFusionTestsFixture_f8e8m0::TestBody();
+TEST_P(GroupNormalizationFusionTransformationTestsF_f8e8m0, GroupNormalizationFusionTransformationTests_f8e8m0) {
+    GroupNormalizationFusionTransformationTestsF_f8e8m0::run();
 }
 
-using RawValuesContainer = std::tuple<PartialShape, Shape, Shape, Shape, Shape, size_t, float>;
+using GroupNormalizationFusionSubgraphTestAdditionalValues =
+    std::tuple<bool>;  // whether it's a positive test that should run reference model or a negative test
 
-std::vector<RawValuesContainer> valid_vals = {
+std::vector<GroupNormalizationFusionTestBaseValues> valid_vals = {
     std::make_tuple(PartialShape{1, 320}, Shape{}, Shape{}, Shape{320}, Shape{320}, 1, 1e-5f),
-    std::make_tuple(Shape{1, 320, 2, 2},
+    std::make_tuple(PartialShape{1, 320, 2, 2},
                     Shape{1, 1, 1},
                     Shape{1, 1, 1},
                     Shape{320, 1, 1},
                     Shape{1, 320, 1, 1},
                     1,
                     1e-5f),
-    std::make_tuple(PartialShape{1, 320, 2, 2, 2},
+    std::make_tuple(PartialShape{5, 320, 2, 2, 2},
                     Shape{1, 320, 1},
                     Shape{1, 320, 1},
                     Shape{320, 1, 1, 1},
                     Shape{320, 1, 1, 1},
                     320,
                     1e-5f),
-    std::make_tuple(PartialShape{Dimension::dynamic(),
-                                 320,
-                                 Dimension::dynamic(),
-                                 Dimension::dynamic(),
-                                 Dimension::dynamic(),
-                                 Dimension::dynamic()},
-                    Shape{1, 320, 1},
-                    Shape{1, 320, 1},
-                    Shape{320, 1, 1, 1, 1},
-                    Shape{320, 1, 1, 1, 1},
-                    320,
-                    1e-5f),
-    std::make_tuple(PartialShape{Dimension::dynamic(), 320},
-                    Shape{32, 1},
-                    Shape{32, 1},
-                    Shape{320},
-                    Shape{320},
-                    32,
+    std::make_tuple(
+        PartialShape{Dimension::dynamic(), 320, Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()},
+        Shape{1, 320, 1},
+        Shape{1, 320, 1},
+        Shape{320, 1, 1, 1},
+        Shape{320, 1, 1, 1},
+        320,
+        1e-5f),
+    std::make_tuple(PartialShape{3, 320}, Shape{32, 1}, Shape{32, 1}, Shape{320}, Shape{320}, 32, 1e-5f),
+    std::make_tuple(PartialShape{2, 9, 4, 5, 6},
+                    Shape{3, 1},
+                    Shape{3, 1},
+                    Shape{1, 9, 1, 1, 1},
+                    Shape{1, 9, 1, 1, 1},
+                    3,
                     1e-5f),
-    std::make_tuple(PartialShape{1, 320, Dimension::dynamic()},
-                    Shape{32, 1},
-                    Shape{32, 1},
-                    Shape{320, 1},
-                    Shape{320, 1},
-                    32,
-                    1e-5f),
-    std::make_tuple(PartialShape{1, 320, 2, Dimension::dynamic()},
+    std::make_tuple(PartialShape{1, 320, 2, 4},
                     Shape{1, 32, 1},
                     Shape{1, 32, 1},
                     Shape{320, 1, 1},
                     Shape{320, 1, 1},
                     32,
                     1e-5f),
-    std::make_tuple(PartialShape{2, 320, 4, 8}, Shape{}, Shape{}, Shape{320, 1, 1}, Shape{1, 320, 1, 1}, 32, 1e-5f),
-    std::make_tuple(PartialShape{1, 512, Dimension::dynamic(), Dimension::dynamic()},
+    std::make_tuple(PartialShape{8, 320, 4, 8}, Shape{}, Shape{}, Shape{320, 1, 1}, Shape{1, 320, 1, 1}, 32, 1e-5f),
+    std::make_tuple(PartialShape{1, 512, 4, 8}, Shape{}, Shape{1, 128, 1}, Shape{1, 512, 1, 1}, Shape{512, 1, 
1}, 128, 1e-6f),
-    std::make_tuple(PartialShape{1, 512, 2, 2},
+    std::make_tuple(PartialShape{1, 192, 2, 2},
                     Shape{1, 64, 1},
                     Shape{},
-                    Shape{1, 512, 1, 1},
-                    Shape{1, 512, 1, 1},
+                    Shape{1, 192, 1, 1},
+                    Shape{1, 192, 1, 1},
                     64,
                     1e-6f)};
 
-auto invalid_vals =
-    Values(std::make_tuple(false, PartialShape{1, 320}, Shape{}, Shape{}, Shape{}, Shape{}, 1, 1e-5f),
-           std::make_tuple(false,
-                           PartialShape{1, 320, 2, 2},
-                           Shape{1, 1, 1},
-                           Shape{1, 1, 1},
-                           Shape{1, 1, 1},
-                           Shape{1, 1, 1, 1},
-                           1,
-                           1e-5f),
-           std::make_tuple(false, PartialShape{1, 320, 2, 2}, Shape{}, Shape{}, Shape{320, 1, 1}, Shape{}, 1, 1e-5f),
-           std::make_tuple(false, PartialShape{1, 320, 2, 2}, Shape{}, Shape{}, Shape{}, Shape{1, 320, 1, 1}, 1, 1e-5f),
-           std::make_tuple(false,
-                           PartialShape{1, 320, 2, 2},
-                           Shape{1, 1, 1},
-                           Shape{1, 32, 1},
-                           Shape{320, 1, 1},
-                           Shape{320, 1, 1},
-                           32,
-                           1e-5f),
-           std::make_tuple(false,
-                           PartialShape{1, 320, 2, 2},
-                           Shape{1, 32, 1},
-                           Shape{1, 1, 1},
-                           Shape{320, 1, 1},
-                           Shape{320, 1, 1},
-                           32,
-                           1e-5f),
-           std::make_tuple(false,
-                           PartialShape{Dimension::dynamic(), 512, Dimension::dynamic(), Dimension::dynamic()},
-                           Shape{},
-                           Shape{},
-                           Shape{1, 512, 1, 1},
-                           Shape{1, 512, 1, 1},
-                           100,
-                           1e-6f));
-
-std::vector<ValuesContainerWithPositiveTestFlag> add_positive_test_flag_to_vals(
-    const bool positive_test,
-    const std::vector<RawValuesContainer>& vals) {
-    std::vector<ValuesContainerWithPositiveTestFlag> res;
-    for (const RawValuesContainer& t : vals) {
-        auto new_val = std::tuple_cat(std::tuple(positive_test), t);
-        res.push_back(new_val);
-    }
-    return res;
-}
-
-INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionPositiveTests_f16,
-                         GroupNormalizationFusionTestsFixture_f16,
-                         ValuesIn(add_positive_test_flag_to_vals(true, valid_vals)));
-
-INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_f16,
-                         GroupNormalizationFusionTestsFixture_f16,
-                         invalid_vals);
-
-INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionPositiveTests_bf16,
-                         GroupNormalizationFusionTestsFixture_bf16,
-                         ValuesIn(add_positive_test_flag_to_vals(true, valid_vals)));
-
-INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_bf16,
-                         GroupNormalizationFusionTestsFixture_bf16,
-                         invalid_vals);
-
-INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionPositiveTests_f32,
-                         GroupNormalizationFusionTestsFixture_f32,
-                         ValuesIn(add_positive_test_flag_to_vals(true, valid_vals)));
-
-INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTests_f32,
-                         GroupNormalizationFusionTestsFixture_f32,
-                         invalid_vals);
-
-INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_u8,
-                         GroupNormalizationFusionTestsFixture_u8,
-                         ValuesIn(add_positive_test_flag_to_vals(false, valid_vals)));
-
-INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_u8,
-                         GroupNormalizationFusionTestsFixture_u8,
-                         invalid_vals);
-
-INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_u16,
-                         GroupNormalizationFusionTestsFixture_u16,
-                         ValuesIn(add_positive_test_flag_to_vals(false, valid_vals)));
-
-INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_u16,
-                         GroupNormalizationFusionTestsFixture_u16,
-                         invalid_vals);
-
-INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_u32,
-                         GroupNormalizationFusionTestsFixture_u32,
-                         ValuesIn(add_positive_test_flag_to_vals(false, valid_vals)));
-
-INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_u32,
-                         GroupNormalizationFusionTestsFixture_u32,
-                         invalid_vals);
-
-INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_u64,
-                         GroupNormalizationFusionTestsFixture_u64,
-                         ValuesIn(add_positive_test_flag_to_vals(false, valid_vals)));
-
-INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_u64,
-                         GroupNormalizationFusionTestsFixture_u64,
-                         invalid_vals);
-
-INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_i8,
-                         GroupNormalizationFusionTestsFixture_i8,
-                         ValuesIn(add_positive_test_flag_to_vals(false, valid_vals)));
-
-INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_i8,
-                         GroupNormalizationFusionTestsFixture_i8,
-                         invalid_vals);
-
-INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_i16,
-                         GroupNormalizationFusionTestsFixture_i16,
-                         ValuesIn(add_positive_test_flag_to_vals(false, valid_vals)));
-
-INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_i16,
-                         GroupNormalizationFusionTestsFixture_i16,
-                         invalid_vals);
-
-INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_i32,
-                         GroupNormalizationFusionTestsFixture_i32,
-                         ValuesIn(add_positive_test_flag_to_vals(false, valid_vals)));
-
-INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_i32,
-                         GroupNormalizationFusionTestsFixture_i32,
-                         invalid_vals);
-
-INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_i64,
-                         GroupNormalizationFusionTestsFixture_i64,
-                         ValuesIn(add_positive_test_flag_to_vals(false, valid_vals)));
-
-INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_i64,
-                         GroupNormalizationFusionTestsFixture_i64,
-                         invalid_vals);
-
-INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_f8e4m3,
-                         GroupNormalizationFusionTestsFixture_f8e4m3,
-                         ValuesIn(add_positive_test_flag_to_vals(false, valid_vals)));
-
-INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_f8e4m3,
-                         GroupNormalizationFusionTestsFixture_f8e4m3,
-                         invalid_vals);
-
-INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_f8e5m2,
-                         GroupNormalizationFusionTestsFixture_f8e5m2,
-                         ValuesIn(add_positive_test_flag_to_vals(false, valid_vals)));
-
-INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_f8e5m2,
-                         GroupNormalizationFusionTestsFixture_f8e5m2,
-                         invalid_vals);
-
-INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_f4e2m1,
-                         GroupNormalizationFusionTestsFixture_f4e2m1,
-                         ValuesIn(add_positive_test_flag_to_vals(false, valid_vals)));
-
-INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_f4e2m1,
-                         GroupNormalizationFusionTestsFixture_f4e2m1,
-                         invalid_vals);
-
-INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsValidVals_f8e8m0,
-                         GroupNormalizationFusionTestsFixture_f8e8m0,
-                         ValuesIn(add_positive_test_flag_to_vals(false, valid_vals)));
+std::vector<GroupNormalizationFusionTestBaseValues> invalid_vals = {
+    std::make_tuple(PartialShape{1, 320}, Shape{}, Shape{}, Shape{}, Shape{}, 1, 1e-5f),
+    std::make_tuple(PartialShape{1, 320, 2, 2},
+                    Shape{1, 1, 1},
+                    Shape{1, 1, 1},
+                    Shape{1, 1, 1},
+                    Shape{1, 1, 1, 1},
+                    1,
+                    1e-5f),
+    std::make_tuple(PartialShape{1, 320, 2, 2}, Shape{}, Shape{}, Shape{320, 1, 1}, Shape{}, 1, 1e-5f),
+    std::make_tuple(PartialShape{1, 320, 2, 2}, Shape{}, Shape{}, Shape{}, Shape{1, 320, 1, 1}, 1, 1e-5f),
+    std::make_tuple(PartialShape{1, 320, 2, 2},
+                    Shape{1, 1, 1},
+                    Shape{1, 32, 1},
+                    Shape{320, 1, 1},
+                    Shape{320, 1, 1},
+                    32,
+                    1e-5f),
+    std::make_tuple(PartialShape{1, 320, 2, 2},
+                    Shape{1, 32, 1},
+                    Shape{1, 1, 1},
+                    Shape{320, 1, 1},
+                    Shape{320, 1, 1},
+                    32,
+                    1e-5f),
+    
std::make_tuple(PartialShape{Dimension::dynamic(), 512, Dimension::dynamic(), Dimension::dynamic()}, + Shape{}, + Shape{}, + Shape{1, 512, 1, 1}, + Shape{1, 512, 1, 1}, + 100, + 1e-6f)}; -INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionNegativeTestsInvalidVals_f8e8m0, - GroupNormalizationFusionTestsFixture_f8e8m0, - invalid_vals); +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationPositiveTests_f32, + GroupNormalizationFusionTransformationTestsF_f32, + ValuesIn(ov::test::expand_vals(valid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(true))), + GroupNormalizationFusionTransformationTestsF_f32::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationPositiveTests_f16, + GroupNormalizationFusionTransformationTestsF_f16, + ValuesIn(ov::test::expand_vals(valid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(true))), + GroupNormalizationFusionTransformationTestsF_f16::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationPositiveTests_bf16, + GroupNormalizationFusionTransformationTestsF_bf16, + ValuesIn(ov::test::expand_vals(valid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(true))), + GroupNormalizationFusionTransformationTestsF_bf16::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTests_f32, + GroupNormalizationFusionTransformationTestsF_f32, + ValuesIn(ov::test::expand_vals(invalid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + GroupNormalizationFusionTransformationTestsF_f32::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTests_f16, + GroupNormalizationFusionTransformationTestsF_f16, + ValuesIn(ov::test::expand_vals(invalid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + GroupNormalizationFusionTransformationTestsF_f16::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTests_bf16, + GroupNormalizationFusionTransformationTestsF_bf16, + ValuesIn(ov::test::expand_vals(invalid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + GroupNormalizationFusionTransformationTestsF_bf16::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsValidVals_u8, + GroupNormalizationFusionTransformationTestsF_u8, + ValuesIn(ov::test::expand_vals(valid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + GroupNormalizationFusionTransformationTestsF_u8::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsValidVals_u16, + GroupNormalizationFusionTransformationTestsF_u16, + ValuesIn(ov::test::expand_vals(valid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + GroupNormalizationFusionTransformationTestsF_u16::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsValidVals_u32, + GroupNormalizationFusionTransformationTestsF_u32, + ValuesIn(ov::test::expand_vals(valid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + GroupNormalizationFusionTransformationTestsF_u32::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsValidVals_u64, + GroupNormalizationFusionTransformationTestsF_u64, + ValuesIn(ov::test::expand_vals(valid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + GroupNormalizationFusionTransformationTestsF_u64::getTestCaseName); + 
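+// Reminder on what the positive-test instantiations earlier in this file verify: per
+// create_ref_model() above, the fusion folds the optional instance norm parameters into
+// the GroupNormalization inputs. With r = numChannels / numGroups (channels per group)
+// and channel index c:
+//
+//   gamma_fused[c] = group_norm_gamma[c] * instance_norm_gamma[c / r]
+//   beta_fused[c]  = group_norm_gamma[c] * instance_norm_beta[c / r] + group_norm_beta[c]
+//
+// For example, with numChannels = 320 and numGroups = 32 (r = 10), channel 25 picks up
+// group 2's instance norm gamma and beta.
+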
+INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsValidVals_i8, + GroupNormalizationFusionTransformationTestsF_i8, + ValuesIn(ov::test::expand_vals(valid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + GroupNormalizationFusionTransformationTestsF_i8::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsValidVals_i16, + GroupNormalizationFusionTransformationTestsF_i16, + ValuesIn(ov::test::expand_vals(valid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + GroupNormalizationFusionTransformationTestsF_i16::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsValidVals_i32, + GroupNormalizationFusionTransformationTestsF_i32, + ValuesIn(ov::test::expand_vals(valid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + GroupNormalizationFusionTransformationTestsF_i32::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsValidVals_f8e5m2, + GroupNormalizationFusionTransformationTestsF_f8e5m2, + ValuesIn(ov::test::expand_vals(valid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + GroupNormalizationFusionTransformationTestsF_f8e5m2::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsValidVals_f4e2m1, + GroupNormalizationFusionTransformationTestsF_f4e2m1, + ValuesIn(ov::test::expand_vals(valid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + GroupNormalizationFusionTransformationTestsF_f4e2m1::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsValidVals_f8e8m0, + GroupNormalizationFusionTransformationTestsF_f8e8m0, + ValuesIn(ov::test::expand_vals(valid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + GroupNormalizationFusionTransformationTestsF_f8e8m0::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsInvalidVals_u8, + GroupNormalizationFusionTransformationTestsF_u8, + ValuesIn(ov::test::expand_vals(invalid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + GroupNormalizationFusionTransformationTestsF_u8::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsInvalidVals_u16, + GroupNormalizationFusionTransformationTestsF_u16, + ValuesIn(ov::test::expand_vals(invalid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + GroupNormalizationFusionTransformationTestsF_u16::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsInvalidVals_u32, + GroupNormalizationFusionTransformationTestsF_u32, + ValuesIn(ov::test::expand_vals(invalid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + GroupNormalizationFusionTransformationTestsF_u32::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsInvalidVals_u64, + GroupNormalizationFusionTransformationTestsF_u64, + ValuesIn(ov::test::expand_vals(invalid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + GroupNormalizationFusionTransformationTestsF_u64::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsInvalidVals_i8, + GroupNormalizationFusionTransformationTestsF_i8, + ValuesIn(ov::test::expand_vals(invalid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + 
GroupNormalizationFusionTransformationTestsF_i8::getTestCaseName);
+
+INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsInvalidVals_i16,
+                         GroupNormalizationFusionTransformationTestsF_i16,
+                         ValuesIn(ov::test::expand_vals(invalid_vals,
+                                                        GroupNormalizationFusionSubgraphTestAdditionalValues(false))),
+                         GroupNormalizationFusionTransformationTestsF_i16::getTestCaseName);
+
+INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsInvalidVals_i32,
+                         GroupNormalizationFusionTransformationTestsF_i32,
+                         ValuesIn(ov::test::expand_vals(invalid_vals,
+                                                        GroupNormalizationFusionSubgraphTestAdditionalValues(false))),
+                         GroupNormalizationFusionTransformationTestsF_i32::getTestCaseName);
+
+INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsInvalidVals_f8e5m2,
+                         GroupNormalizationFusionTransformationTestsF_f8e5m2,
+                         ValuesIn(ov::test::expand_vals(invalid_vals,
+                                                        GroupNormalizationFusionSubgraphTestAdditionalValues(false))),
+                         GroupNormalizationFusionTransformationTestsF_f8e5m2::getTestCaseName);
+
+INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsInvalidVals_f4e2m1,
+                         GroupNormalizationFusionTransformationTestsF_f4e2m1,
+                         ValuesIn(ov::test::expand_vals(invalid_vals,
+                                                        GroupNormalizationFusionSubgraphTestAdditionalValues(false))),
+                         GroupNormalizationFusionTransformationTestsF_f4e2m1::getTestCaseName);
+
+INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsInvalidVals_f8e8m0,
+                         GroupNormalizationFusionTransformationTestsF_f8e8m0,
+                         ValuesIn(ov::test::expand_vals(invalid_vals,
+                                                        GroupNormalizationFusionSubgraphTestAdditionalValues(false))),
+                         GroupNormalizationFusionTransformationTestsF_f8e8m0::getTestCaseName);
\ No newline at end of file

From c7c9c375f3d2d1ec5db2ac373bf735ebc8d64ec2 Mon Sep 17 00:00:00 2001
From: Jedrzej Hajduczenia
Date: Wed, 5 Feb 2025 00:12:12 +0100
Subject: [PATCH 31/45] Add missing include file in GroupNormalizationFusion
 shared functional subgraph test

---
 .../shared_test_classes/subgraph/group_normalization_fusion.hpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp b/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp
index 61e40355541a4b..43a354935524aa 100644
--- a/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp
+++ b/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp
@@ -5,6 +5,7 @@
 #pragma once
 
 #include "common_test_utils/data_utils.hpp"
+#include "common_test_utils/ov_tensor_utils.hpp"
 #include "common_test_utils/ov_test_utils.hpp"
 #include "functional_test_utils/crash_handler.hpp"
 #include "shared_test_classes/base/ov_subgraph.hpp"

From d1d6a24d80ed72b02e4540b12d564b36cd217f13 Mon Sep 17 00:00:00 2001
From: Jedrzej Hajduczenia
Date: Wed, 5 Feb 2025 00:14:18 +0100
Subject: [PATCH 32/45] Cosmetic changes in ov::test::SubgraphBaseTest class

---
 .../include/shared_test_classes/base/ov_subgraph.hpp         | 2 +-
 .../functional/shared_test_classes/src/base/ov_subgraph.cpp | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/tests/functional/shared_test_classes/include/shared_test_classes/base/ov_subgraph.hpp b/src/tests/functional/shared_test_classes/include/shared_test_classes/base/ov_subgraph.hpp
index 1b4eb35689bd58..c71c67ae2562e4 100644
---
a/src/tests/functional/shared_test_classes/include/shared_test_classes/base/ov_subgraph.hpp
+++ b/src/tests/functional/shared_test_classes/include/shared_test_classes/base/ov_subgraph.hpp
@@ -41,12 +41,12 @@ class SubgraphBaseTest : public ov::test::TestsCommon {
     virtual void validate();
     virtual void configure_model();
     virtual void generate_inputs(const std::vector& targetInputStaticShapes);
+    virtual void init_thresholds();
     void compare_models_param_res(const std::shared_ptr& f, const std::shared_ptr& f_ref);
     void compare_nodes(const std::shared_ptr& node1, const std::shared_ptr& node2, std::ostream& err_log);
     void update_ref_model();
     void match_parameters(const ov::ParameterVector& params, const ov::ParameterVector& ref_params);
-    void init_thresholds();
     void init_input_shapes(const std::vector& shapes);
 
     void set_callback_exception(std::function callback) { callback_exception = callback; }

diff --git a/src/tests/functional/shared_test_classes/src/base/ov_subgraph.cpp b/src/tests/functional/shared_test_classes/src/base/ov_subgraph.cpp
index 00b5b10189b8e5..40ddccc734b4c7 100644
--- a/src/tests/functional/shared_test_classes/src/base/ov_subgraph.cpp
+++ b/src/tests/functional/shared_test_classes/src/base/ov_subgraph.cpp
@@ -504,7 +504,7 @@ void SubgraphBaseTest::validate() {
     }
 
     ASSERT_EQ(actualOutputs.size(), expectedOutputs.size())
-        << "TEMPLATE plugin has " << expectedOutputs.size() << " outputs, while " << targetDevice << " " << actualOutputs.size();
+        << "Reference has " << expectedOutputs.size() << " outputs, while " << targetDevice << " " << actualOutputs.size();
 
     if (is_report_stages) {
         std::cout << "[ COMPARATION ] `ov_tensor_utils.hpp::compare()` is started"<< std::endl;
     }

From 1bd213bdce0ff4781b77fbe2e982732c1b16b8fa Mon Sep 17 00:00:00 2001
From: Jedrzej Hajduczenia
Date: Wed, 5 Feb 2025 00:27:02 +0100
Subject: [PATCH 33/45] Remove redundant virtual keyword in
 GroupNormalizationFusion shared functional subgraph test fixture class

---
 .../shared_test_classes/subgraph/group_normalization_fusion.hpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp b/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp
index 43a354935524aa..3ddc264000646b 100644
--- a/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp
+++ b/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp
@@ -350,7 +350,7 @@ class GroupNormalizationFusionSubgraphTestsF
         return outputs;
     }
 
-    virtual void generate_inputs(const std::vector& targetInputStaticShapes) override {
+    void generate_inputs(const std::vector& targetInputStaticShapes) override {
         inputs.clear();
         auto itTargetShape = targetInputStaticShapes.begin();

From be969953ef2f3aa4b70ec2ae13dc5775bcacf148 Mon Sep 17 00:00:00 2001
From: Jedrzej Hajduczenia
Date: Wed, 5 Feb 2025 05:14:46 +0100
Subject: [PATCH 34/45] Fix accessing types and member variables/functions
 from GroupNormalizationFusionTestBase in derived classes' templates

---
 .../group_normalization_fusion_tests.cpp     | 171 +++++++++---------
 .../subgraph/group_normalization_fusion.hpp  |  53 +++---
 2 files changed, 119 insertions(+), 105 deletions(-)

diff --git a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp
b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp
index abcac8b84011fa..9f8598de5be10a 100644
--- a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp
+++ b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp
@@ -24,80 +24,8 @@ template class GroupNormalizationFusionTransformationTestsF
     : public ov::test::GroupNormalizationFusionTestBase,
       public testing::TestWithParam {
-protected:
-    bool positiveTest;
-    ov::pass::Manager manager;
-    ov::pass::Manager manager_ref;
-    std::shared_ptr model;
-    std::shared_ptr model_ref;
-
-    virtual void read_test_parameters() {
-        const auto& params = GetParam();
-
-        dataShape = std::get<0>(params);
-        if (!dataShape.rank().is_static())
-            throw std::runtime_error("Rank of input tensor has to be static!");
-        if (dataShape.rank().get_max_length() < 2)
-            throw std::runtime_error("Expected at least two dimensions in input tensor!");
-        if (!dataShape[1].is_static())
-            throw std::runtime_error("Channel dimension in input tensor has to be static!");
-
-        numChannels = static_cast(dataShape[1].get_max_length());
-        instanceNormGammaShape = std::get<1>(params);
-        instanceNormBetaShape = std::get<2>(params);
-        groupNormGammaShape = std::get<3>(params);
-        groupNormBetaShape = std::get<4>(params);
-        numGroups = std::get<5>(params);
-        epsilon = std::get<6>(params);
-        positiveTest = std::get<7>(params);
-
-        instanceNormGammaPresent = (instanceNormGammaShape != Shape{});
-        instanceNormBetaPresent = (instanceNormBetaShape != Shape{});
-
-        if (positiveTest) {
-            if ((instanceNormGammaShape != Shape{}) && (shape_size(instanceNormGammaShape) != numGroups))
-                throw std::runtime_error("Shape of instance norm gamma has to either be empty or contain "
-                                         "exactly <num_groups> elements");
-            if ((instanceNormBetaShape != Shape{}) && (shape_size(instanceNormBetaShape) != numGroups))
-                throw std::runtime_error("Shape of instance norm beta has to either be empty shape or contain "
-                                         "exactly <num_groups> elements");
-            if (shape_size(groupNormGammaShape) != numChannels)
-                throw std::runtime_error("Shape of group norm gamma has to contain exactly <num_channels> elements");
-            if (shape_size(groupNormBetaShape) != numChannels)
-                throw std::runtime_error("Shape of group norm beta has to contain exactly <num_channels> elements");
-
-            instanceNormGammaPresent = instanceNormGammaPresent && (shape_size(instanceNormGammaShape) == numGroups);
-            instanceNormBetaPresent = instanceNormBetaPresent && (shape_size(instanceNormBetaShape) == numGroups);
-        }
-    }
-
-    std::shared_ptr create_ref_model() {
-        auto input = std::make_shared(T_elem_t, dataShape);
-
-        auto group_norm_beta_corr_vals = groupNormBetaVals;
-        if (instanceNormBetaPresent)
-            for (auto i = 0; i < group_norm_beta_corr_vals.size(); i++)
-                group_norm_beta_corr_vals[i] =
-                    groupNormGammaVals[i] * instanceNormBetaVals[i / (numChannels / numGroups)] + groupNormBetaVals[i];
-        auto group_norm_beta_1d = op::v0::Constant::create(T_elem_t, Shape{numChannels}, group_norm_beta_corr_vals);
-
-        auto group_norm_gamma_corr_vals = groupNormGammaVals;
-        if (instanceNormGammaPresent)
-            for (auto i = 0; i < group_norm_gamma_corr_vals.size(); i++)
-                group_norm_gamma_corr_vals[i] =
-                    groupNormGammaVals[i] * instanceNormGammaVals[i / (numChannels / numGroups)];
-        auto group_norm_gamma_1d = op::v0::Constant::create(T_elem_t, Shape{numChannels}, group_norm_gamma_corr_vals);
-
-        auto group_norm = std::make_shared(input,
-                                           group_norm_gamma_1d,
-                                           group_norm_beta_1d,
-                                           numGroups,
-                                           epsilon);
-
-        return
std::make_shared(NodeVector{group_norm}, ParameterVector{input}); - } - public: + static constexpr element::Type T_elem_t = T_elem; static std::string getTestCaseName(const testing::TestParamInfo& obj) { const auto& params = obj.param; @@ -127,8 +55,8 @@ class GroupNormalizationFusionTransformationTestsF void run() { read_test_parameters(); - generate_weights_init_values(); - model = create_model(); + this->generate_weights_init_values(); + model = this->create_model(); manager = ov::pass::Manager(); manager.register_pass(); @@ -164,7 +92,7 @@ class GroupNormalizationFusionTransformationTestsF ASSERT_EQ(f_results[0]->get_output_partial_shape(0), f_ref_results[0]->get_output_partial_shape(0)); ASSERT_EQ(f_results[0]->get_output_partial_shape(0), f_parameters[0]->get_output_partial_shape(0)); ASSERT_EQ(f_ref_results[0]->get_output_partial_shape(0), f_ref_parameters[0]->get_output_partial_shape(0)); - ASSERT_EQ(f_ref_results[0]->get_output_partial_shape(0), dataShape); + ASSERT_EQ(f_ref_results[0]->get_output_partial_shape(0), this->dataShape); const auto& gn_node = f_results[0]->get_input_node_shared_ptr(0); const auto& gn_ref_node = f_ref_results[0]->get_input_node_shared_ptr(0); @@ -173,22 +101,103 @@ class GroupNormalizationFusionTransformationTestsF ASSERT_EQ(gn_node->inputs().size(), gn_ref_node->inputs().size()); ASSERT_EQ(gn_node->inputs().size(), 3); ASSERT_EQ(gn_node->get_input_partial_shape(0), gn_ref_node->get_input_partial_shape(0)); - ASSERT_EQ(gn_node->get_input_partial_shape(0), dataShape); + ASSERT_EQ(gn_node->get_input_partial_shape(0), this->dataShape); ASSERT_EQ(shape_size(gn_node->get_input_shape(1)), shape_size(gn_ref_node->get_input_shape(1))); - ASSERT_EQ(shape_size(gn_node->get_input_shape(1)), numChannels); + ASSERT_EQ(shape_size(gn_node->get_input_shape(1)), this->numChannels); ASSERT_EQ(shape_size(gn_node->get_input_shape(2)), shape_size(gn_ref_node->get_input_shape(2))); - ASSERT_EQ(shape_size(gn_node->get_input_shape(2)), numChannels); + ASSERT_EQ(shape_size(gn_node->get_input_shape(2)), this->numChannels); const auto& gn_node_casted = ov::as_type_ptr(gn_node); const auto& gn_ref_node_casted = ov::as_type_ptr(gn_ref_node); ASSERT_EQ(gn_node_casted->get_epsilon(), gn_ref_node_casted->get_epsilon()); - ASSERT_EQ(gn_node_casted->get_epsilon(), epsilon); + ASSERT_EQ(gn_node_casted->get_epsilon(), this->epsilon); ASSERT_EQ(gn_node_casted->get_num_groups(), gn_ref_node_casted->get_num_groups()); - ASSERT_EQ(gn_node_casted->get_num_groups(), numGroups); + ASSERT_EQ(gn_node_casted->get_num_groups(), this->numGroups); } else { ASSERT_EQ(count_ops_of_type(model), 0); } } + +protected: + bool positiveTest; + ov::pass::Manager manager; + ov::pass::Manager manager_ref; + std::shared_ptr model; + std::shared_ptr model_ref; + + void read_test_parameters() override { + const auto& params = GetParam(); + + this->dataShape = std::get<0>(params); + if (!this->dataShape.rank().is_static()) + throw std::runtime_error("Rank of input tensor has to be static!"); + if (this->dataShape.rank().get_max_length() < 2) + throw std::runtime_error("Expected at least two dimensions in input tensor!"); + if (!this->dataShape[1].is_static()) + throw std::runtime_error("Channel dimension in input tensor has to be static!"); + + this->numChannels = static_cast(this->dataShape[1].get_max_length()); + this->instanceNormGammaShape = std::get<1>(params); + this->instanceNormBetaShape = std::get<2>(params); + this->groupNormGammaShape = std::get<3>(params); + this->groupNormBetaShape = 
std::get<4>(params);
+        this->numGroups = std::get<5>(params);
+        this->epsilon = std::get<6>(params);
+        positiveTest = std::get<7>(params);
+
+        this->instanceNormGammaPresent = (this->instanceNormGammaShape != Shape{});
+        this->instanceNormBetaPresent = (this->instanceNormBetaShape != Shape{});
+
+        if (positiveTest) {
+            if ((this->instanceNormGammaShape != Shape{}) &&
+                (shape_size(this->instanceNormGammaShape) != this->numGroups))
+                throw std::runtime_error("Shape of instance norm gamma has to either be empty or contain "
+                                         "exactly <num_groups> elements");
+            if ((this->instanceNormBetaShape != Shape{}) &&
+                (shape_size(this->instanceNormBetaShape) != this->numGroups))
+                throw std::runtime_error("Shape of instance norm beta has to either be empty shape or contain "
+                                         "exactly <num_groups> elements");
+            if (shape_size(this->groupNormGammaShape) != this->numChannels)
+                throw std::runtime_error("Shape of group norm gamma has to contain exactly <num_channels> elements");
+            if (shape_size(this->groupNormBetaShape) != this->numChannels)
+                throw std::runtime_error("Shape of group norm beta has to contain exactly <num_channels> elements");
+
+            this->instanceNormGammaPresent =
+                this->instanceNormGammaPresent && (shape_size(this->instanceNormGammaShape) == this->numGroups);
+            this->instanceNormBetaPresent =
+                this->instanceNormBetaPresent && (shape_size(this->instanceNormBetaShape) == this->numGroups);
+        }
+    }
+
+    std::shared_ptr create_ref_model() {
+        auto input = std::make_shared(T_elem_t, this->dataShape);
+
+        auto group_norm_beta_corr_vals = this->groupNormBetaVals;
+        if (this->instanceNormBetaPresent)
+            for (auto i = 0; i < group_norm_beta_corr_vals.size(); i++)
+                group_norm_beta_corr_vals[i] =
+                    this->groupNormGammaVals[i] *
+                        this->instanceNormBetaVals[i / (this->numChannels / this->numGroups)] +
+                    this->groupNormBetaVals[i];
+        auto group_norm_beta_1d =
+            op::v0::Constant::create(T_elem_t, Shape{this->numChannels}, group_norm_beta_corr_vals);
+
+        auto group_norm_gamma_corr_vals = this->groupNormGammaVals;
+        if (this->instanceNormGammaPresent)
+            for (auto i = 0; i < group_norm_gamma_corr_vals.size(); i++)
+                group_norm_gamma_corr_vals[i] = this->groupNormGammaVals[i] *
+                                                this->instanceNormGammaVals[i / (this->numChannels / this->numGroups)];
+        auto group_norm_gamma_1d =
+            op::v0::Constant::create(T_elem_t, Shape{this->numChannels}, group_norm_gamma_corr_vals);
+
+        auto group_norm = std::make_shared(input,
+                                           group_norm_gamma_1d,
+                                           group_norm_beta_1d,
+                                           this->numGroups,
+                                           this->epsilon);
+
+        return std::make_shared(NodeVector{group_norm}, ParameterVector{input});
+    }
 };
 
 class GroupNormalizationFusionTransformationTestsF_f32
diff --git a/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp b/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp
index 3ddc264000646b..761e99daad3599 100644
--- a/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp
+++ b/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp
@@ -132,6 +132,7 @@ class GroupNormalizationFusionSubgraphTestsF
     public ov::test::SubgraphBaseTest,
     public testing::WithParamInterface {
 public:
+    static constexpr element::Type T_elem_t = T_elem;
     static std::string getTestCaseName(
         const testing::TestParamInfo& obj) {
         const auto& params = obj.param;
@@ -194,32 +195,32 @@ class GroupNormalizationFusionSubgraphTestsF
         SubgraphBaseTest::TearDown();
     }
 
-    virtual void
read_test_parameters() {
+    void read_test_parameters() override {
         const auto& params = GetParam();
 
-        dataShape = std::get<0>(params);
-        if (!dataShape.rank().is_static())
+        this->dataShape = std::get<0>(params);
+        if (!this->dataShape.rank().is_static())
             throw std::runtime_error("Rank of input tensor has to be static!");
-        if (dataShape.rank().get_max_length() < 2)
+        if (this->dataShape.rank().get_max_length() < 2)
             throw std::runtime_error("Expected at least two dimensions in input tensor!");
-        if (!dataShape[1].is_static())
+        if (!this->dataShape[1].is_static())
             throw std::runtime_error("Channel dimension in input tensor has to be static!");
 
-        numChannels = static_cast(dataShape[1].get_max_length());
-        instanceNormGammaShape = std::get<1>(params);
-        instanceNormBetaShape = std::get<2>(params);
-        groupNormGammaShape = std::get<3>(params);
-        groupNormBetaShape = std::get<4>(params);
-        numGroups = std::get<5>(params);
-        epsilon = std::get<6>(params);
+        this->numChannels = static_cast(this->dataShape[1].get_max_length());
+        this->instanceNormGammaShape = std::get<1>(params);
+        this->instanceNormBetaShape = std::get<2>(params);
+        this->groupNormGammaShape = std::get<3>(params);
+        this->groupNormBetaShape = std::get<4>(params);
+        this->numGroups = std::get<5>(params);
+        this->epsilon = std::get<6>(params);
         positiveTest = std::get<7>(params);
         targetDeviceName = std::get<8>(params);
         targetConfiguration = std::get<9>(params);
         refDevice = std::get<10>(params);
         refConfiguration = std::get<11>(params);
 
-        instanceNormGammaPresent = (instanceNormGammaShape != Shape{});
-        instanceNormBetaPresent = (instanceNormBetaShape != Shape{});
+        this->instanceNormGammaPresent = (this->instanceNormGammaShape != Shape{});
+        this->instanceNormBetaPresent = (this->instanceNormBetaShape != Shape{});
 
         inType = T_elem_t;
         outType = T_elem_t;
@@ -227,19 +228,23 @@ class GroupNormalizationFusionSubgraphTestsF
         configuration = targetConfiguration;
 
         if (positiveTest) {
-            if ((instanceNormGammaShape != Shape{}) && (shape_size(instanceNormGammaShape) != numGroups))
+            if ((this->instanceNormGammaShape != Shape{}) &&
+                (shape_size(this->instanceNormGammaShape) != this->numGroups))
                 throw std::runtime_error("Shape of instance norm gamma has to either be empty or contain "
                                          "exactly <num_groups> elements");
-            if ((instanceNormBetaShape != Shape{}) && (shape_size(instanceNormBetaShape) != numGroups))
+            if ((this->instanceNormBetaShape != Shape{}) &&
+                (shape_size(this->instanceNormBetaShape) != this->numGroups))
                 throw std::runtime_error("Shape of instance norm beta has to either be empty shape or contain "
                                          "exactly <num_groups> elements");
-            if (shape_size(groupNormGammaShape) != numChannels)
+            if (shape_size(this->groupNormGammaShape) != this->numChannels)
                 throw std::runtime_error("Shape of group norm gamma has to contain exactly <num_channels> elements");
-            if (shape_size(groupNormBetaShape) != numChannels)
+            if (shape_size(this->groupNormBetaShape) != this->numChannels)
                 throw std::runtime_error("Shape of group norm beta has to contain exactly <num_channels> elements");
 
-            instanceNormGammaPresent = instanceNormGammaPresent && (shape_size(instanceNormGammaShape) == numGroups);
-            instanceNormBetaPresent = instanceNormBetaPresent && (shape_size(instanceNormBetaShape) == numGroups);
+            this->instanceNormGammaPresent =
+                this->instanceNormGammaPresent && (shape_size(this->instanceNormGammaShape) == this->numGroups);
+            this->instanceNormBetaPresent =
+                this->instanceNormBetaPresent && (shape_size(this->instanceNormBetaShape) == this->numGroups);
         }
     }
@@ -305,7 +310,7 @@ class
GroupNormalizationFusionSubgraphTestsF void init_thresholds() override { if (!targetStaticShapes.empty()) { - size_t problem_size = shape_size(dataShape.get_shape()); + size_t problem_size = shape_size(this->dataShape.get_shape()); abs_threshold = pow(problem_size, 0.5) * test::utils::get_eps_by_ov_type(outType); rel_threshold = abs_threshold; @@ -399,8 +404,8 @@ class GroupNormalizationFusionSubgraphTestsF std::string errorMessage; try { read_test_parameters(); - generate_weights_init_values(); - functionRefs = create_model(); + this->generate_weights_init_values(); + functionRefs = this->create_model(); function = functionRefs->clone(); pass::Manager m; m.register_pass(); @@ -415,7 +420,7 @@ class GroupNormalizationFusionSubgraphTestsF if (!function->is_dynamic()) { configure_device(); configure_ref_device(); - auto input_shapes = static_partial_shapes_to_test_representation({dataShape}); + auto input_shapes = static_partial_shapes_to_test_representation({this->dataShape}); init_input_shapes(input_shapes); ASSERT_FALSE(targetStaticShapes.empty() && !function->get_parameters().empty()) << "Target Static Shape is empty!!!"; From 0fa9aa849dc75d80e82a8837687af4bec91382a3 Mon Sep 17 00:00:00 2001 From: Jedrzej Hajduczenia Date: Wed, 5 Feb 2025 14:57:33 +0100 Subject: [PATCH 35/45] Add missing override keywords in GroupNormalizationFusion shared functional subgraph test --- .../subgraph/group_normalization_fusion.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp b/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp index 761e99daad3599..4365d96558f0bb 100644 --- a/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp +++ b/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp @@ -325,7 +325,7 @@ class GroupNormalizationFusionSubgraphTestsF refInferRequest.infer(); } - std::vector calculate_refs() { + std::vector calculate_refs() override { if (is_report_stages) { std::cout << "[ REFERENCE ] `GroupNormalizationFusionSubgraphTestsF::calculate_refs()` is started" << std::endl; @@ -378,7 +378,7 @@ class GroupNormalizationFusionSubgraphTestsF } public: - void run() { + void run() override { is_reported = true; bool isCurrentTestDisabled = ov::test::utils::current_test_is_disabled(); From b6f7826db632b1b2250b5e3ca80b811657a0aba5 Mon Sep 17 00:00:00 2001 From: Jedrzej Hajduczenia Date: Wed, 5 Feb 2025 15:00:38 +0100 Subject: [PATCH 36/45] Fix usage of ov::element::Type_t and ov::element::Type in GroupNormalizationFusion tests --- .../group_normalization_fusion_tests.cpp | 13 ++++----- .../subgraph/group_normalization_fusion.hpp | 28 +++++++++---------- 2 files changed, 20 insertions(+), 21 deletions(-) diff --git a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp index 9f8598de5be10a..d84aad51d9a3e8 100644 --- a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp +++ b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp @@ -20,12 +20,12 @@ using GroupNormalizationFusionSubgraphTestValues = float, // epsilon bool>; // whether it's a positive test that should run reference model or a negative test 
-template +template class GroupNormalizationFusionTransformationTestsF - : public ov::test::GroupNormalizationFusionTestBase, + : public ov::test::GroupNormalizationFusionTestBase, public testing::TestWithParam { public: - static constexpr element::Type T_elem_t = T_elem; + static constexpr element::Type T_elem = T_elem_t; static std::string getTestCaseName(const testing::TestParamInfo& obj) { const auto& params = obj.param; @@ -170,7 +170,7 @@ class GroupNormalizationFusionTransformationTestsF } std::shared_ptr create_ref_model() { - auto input = std::make_shared(T_elem_t, this->dataShape); + auto input = std::make_shared(T_elem, this->dataShape); auto group_norm_beta_corr_vals = this->groupNormBetaVals; if (this->instanceNormBetaPresent) @@ -179,8 +179,7 @@ class GroupNormalizationFusionTransformationTestsF this->groupNormGammaVals[i] * this->instanceNormBetaVals[i / (this->numChannels / this->numGroups)] + this->groupNormBetaVals[i]; - auto group_norm_beta_1d = - op::v0::Constant::create(T_elem_t, Shape{this->numChannels}, group_norm_beta_corr_vals); + auto group_norm_beta_1d = op::v0::Constant::create(T_elem, Shape{this->numChannels}, group_norm_beta_corr_vals); auto group_norm_gamma_corr_vals = this->groupNormGammaVals; if (this->instanceNormGammaPresent) @@ -188,7 +187,7 @@ class GroupNormalizationFusionTransformationTestsF group_norm_gamma_corr_vals[i] = this->groupNormGammaVals[i] * this->instanceNormGammaVals[i / (this->numChannels / this->numGroups)]; auto group_norm_gamma_1d = - op::v0::Constant::create(T_elem_t, Shape{this->numChannels}, group_norm_gamma_corr_vals); + op::v0::Constant::create(T_elem, Shape{this->numChannels}, group_norm_gamma_corr_vals); auto group_norm = std::make_shared(input, group_norm_gamma_1d, diff --git a/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp b/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp index 4365d96558f0bb..c89e8e366ddb6a 100644 --- a/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp +++ b/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp @@ -50,10 +50,10 @@ std::vector> expand_vals(std::vector< return res; } -template +template class GroupNormalizationFusionTestBase { public: - static constexpr element::Type T_elem_t = T_elem; + static constexpr element::Type T_elem = T_elem_t; typedef typename ov::element_type_traits::value_type T_store_t; protected: @@ -86,7 +86,7 @@ class GroupNormalizationFusionTestBase { } std::shared_ptr create_model() { - auto input = std::make_shared(T_elem_t, dataShape); + auto input = std::make_shared(T_elem, dataShape); auto pre_mvn_shape_const = op::v0::Constant::create(element::i64, Shape{3}, {0, static_cast(numGroups), -1}); auto pre_mvn_reshape = std::make_shared(input, pre_mvn_shape_const, true); @@ -98,14 +98,14 @@ class GroupNormalizationFusionTestBase { std::shared_ptr opt_instance_norm_gamma_multiply = mvn; if (instanceNormGammaPresent) { auto instance_norm_gamma_const = - op::v0::Constant::create(T_elem_t, instanceNormGammaShape, instanceNormGammaVals); + op::v0::Constant::create(T_elem, instanceNormGammaShape, instanceNormGammaVals); opt_instance_norm_gamma_multiply = std::make_shared(mvn, instance_norm_gamma_const); } std::shared_ptr opt_instance_norm_beta_add = opt_instance_norm_gamma_multiply; if (instanceNormBetaPresent) { auto instance_norm_beta_const = - 
op::v0::Constant::create(T_elem_t, instanceNormBetaShape, instanceNormBetaVals); + op::v0::Constant::create(T_elem, instanceNormBetaShape, instanceNormBetaVals); opt_instance_norm_beta_add = std::make_shared(opt_instance_norm_gamma_multiply, instance_norm_beta_const); } @@ -115,24 +115,24 @@ class GroupNormalizationFusionTestBase { auto post_instance_norm_reshape = std::make_shared(opt_instance_norm_beta_add, post_instance_norm_shape, true); - auto group_norm_gamma_const = op::v0::Constant::create(T_elem_t, groupNormGammaShape, groupNormGammaVals); + auto group_norm_gamma_const = op::v0::Constant::create(T_elem, groupNormGammaShape, groupNormGammaVals); auto group_norm_gamma_multiply = std::make_shared(post_instance_norm_reshape, group_norm_gamma_const); - auto group_norm_beta_const = op::v0::Constant::create(T_elem_t, groupNormBetaShape, groupNormBetaVals); + auto group_norm_beta_const = op::v0::Constant::create(T_elem, groupNormBetaShape, groupNormBetaVals); auto group_norm_beta_add = std::make_shared(group_norm_gamma_multiply, group_norm_beta_const); return std::make_shared(NodeVector{group_norm_beta_add}, ParameterVector{input}); } }; -template +template class GroupNormalizationFusionSubgraphTestsF - : public GroupNormalizationFusionTestBase, + : public GroupNormalizationFusionTestBase, public ov::test::SubgraphBaseTest, public testing::WithParamInterface { public: - static constexpr element::Type T_elem_t = T_elem; + static constexpr element::Type T_elem = T_elem_t; static std::string getTestCaseName( const testing::TestParamInfo& obj) { const auto& params = obj.param; @@ -222,8 +222,8 @@ class GroupNormalizationFusionSubgraphTestsF this->instanceNormGammaPresent = (this->instanceNormGammaShape != Shape{}); this->instanceNormBetaPresent = (this->instanceNormBetaShape != Shape{}); - inType = T_elem_t; - outType = T_elem_t; + inType = T_elem; + outType = T_elem; targetDevice = targetDeviceName; configuration = targetConfiguration; @@ -250,13 +250,13 @@ class GroupNormalizationFusionSubgraphTestsF void configure_device() { if (targetConfiguration.count(ov::hint::inference_precision.name()) <= 0) { - targetConfiguration.insert({ov::hint::inference_precision.name(), T_elem_t}); + targetConfiguration.insert({ov::hint::inference_precision.name(), T_elem}); } } void configure_ref_device() { if (refConfiguration.count(ov::hint::inference_precision.name()) <= 0) { - refConfiguration.insert({ov::hint::inference_precision.name(), T_elem_t}); + refConfiguration.insert({ov::hint::inference_precision.name(), T_elem}); } } From c85c28dc600e6af4650c7a86444b7b49838d7221 Mon Sep 17 00:00:00 2001 From: Jedrzej Hajduczenia Date: Wed, 5 Feb 2025 15:04:10 +0100 Subject: [PATCH 37/45] Fix comparison of integer expressions of different signedness in GroupNormalizationFusion pass and tests --- .../group_normalization_fusion.cpp | 20 ++++++++++--------- .../group_normalization_fusion_tests.cpp | 4 ++-- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp index d3bcd456e1281e..92a2482e32bcab 100644 --- a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp @@ -27,15 +27,17 @@ bool pre_mvn_shape_vals_correct(const std::shared_ptr& pre bool res = true; 
std::vector pre_mvn_shape_vals = pre_mvn_shape_const->get_vector(); if (input_ps[0].is_dynamic()) { - if (pre_mvn_shape_vals[0] != 0) + if (static_cast(pre_mvn_shape_vals[0]) != 0ll) res = false; } else { - if ((pre_mvn_shape_vals[0] != 0) && (pre_mvn_shape_vals[0] != input_ps[0].get_max_length())) + if ((static_cast(pre_mvn_shape_vals[0]) != 0ll) && + (static_cast(pre_mvn_shape_vals[0]) != static_cast(input_ps[0].get_max_length()))) res = false; } - if ((pre_mvn_shape_vals[1] != 0) && (pre_mvn_shape_vals[1] != num_groups)) + if ((static_cast(pre_mvn_shape_vals[1]) != 0ll) && + (static_cast(pre_mvn_shape_vals[1]) != static_cast(num_groups))) res = false; - if (pre_mvn_shape_vals[2] != -1) + if (static_cast(pre_mvn_shape_vals[2]) != -1ll) res = false; return res; } @@ -44,7 +46,7 @@ template ::value, bool> = true> bool mvn_reduction_axes_correct(const std::shared_ptr& mvn_reduction_axes_const) { bool res = true; std::vector mvn_reduce_axes = mvn_reduction_axes_const->get_vector(); - if ((mvn_reduce_axes[0] != 2) && (mvn_reduce_axes[0] != -1)) + if ((static_cast(mvn_reduce_axes[0]) != 2ll) && (static_cast(mvn_reduce_axes[0]) != -1ll)) return false; return res; } @@ -101,8 +103,8 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { const auto& pre_mvn_reshape_out_ps = pattern_map.at(pre_mvn_reshape_m).get_partial_shape(); - const auto& num_channels = input_ps[1].get_max_length(); - const auto& num_groups = pre_mvn_reshape_out_ps[1].get_max_length(); + const size_t num_channels = static_cast(input_ps[1].get_max_length()); + const size_t num_groups = static_cast(pre_mvn_reshape_out_ps[1].get_max_length()); // we expect to reshape input in a way that would merge all spatial dimensions // but leave batch and channel dimensions untouched @@ -206,7 +208,7 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { if (ov::shape_size(group_norm_beta.get_shape()) != num_channels) return false; - auto expected_param_shape = ov::PartialShape({num_channels}); + auto expected_param_shape = ov::PartialShape({static_cast(num_channels)}); std::shared_ptr group_norm_gamma_1d_m = std::make_shared(group_norm_gamma); const auto& group_norm_gamma_1d_out_ps = group_norm_gamma_1d_m->get_output_partial_shape(0); @@ -222,7 +224,7 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { auto gather_axis_const_m = op::v0::Constant::create(element::i64, Shape{1}, {0}); auto gather_indices_vals = std::vector(); - for (auto i = 0; i < num_groups; i++) + for (auto i = 0ull; i < num_groups; i++) gather_indices_vals.insert(gather_indices_vals.end(), channels_to_groups_ratio, i); auto gather_indices_const_m = op::v0::Constant::create(element::i64, Shape{static_cast(num_channels)}, gather_indices_vals); diff --git a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp index d84aad51d9a3e8..101ee04b758430 100644 --- a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp +++ b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp @@ -174,7 +174,7 @@ class GroupNormalizationFusionTransformationTestsF auto group_norm_beta_corr_vals = this->groupNormBetaVals; if (this->instanceNormBetaPresent) - for (auto i = 0; i < group_norm_beta_corr_vals.size(); i++) + for (auto i = 0ull; i < group_norm_beta_corr_vals.size(); i++) group_norm_beta_corr_vals[i] = this->groupNormGammaVals[i] * 
this->instanceNormBetaVals[i / (this->numChannels / this->numGroups)] + @@ -183,7 +183,7 @@ class GroupNormalizationFusionTransformationTestsF auto group_norm_gamma_corr_vals = this->groupNormGammaVals; if (this->instanceNormGammaPresent) - for (auto i = 0; i < group_norm_gamma_corr_vals.size(); i++) + for (auto i = 0ull; i < group_norm_gamma_corr_vals.size(); i++) group_norm_gamma_corr_vals[i] = this->groupNormGammaVals[i] * this->instanceNormGammaVals[i / (this->numChannels / this->numGroups)]; auto group_norm_gamma_1d = From d98d69b99e9459881c72c97c5d5d9d6f151bf17d Mon Sep 17 00:00:00 2001 From: Jedrzej Hajduczenia Date: Wed, 5 Feb 2025 15:25:09 +0100 Subject: [PATCH 38/45] Override init_thresholds() in MHA shared functional test class --- src/tests/functional/plugin/shared/include/snippets/mha.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tests/functional/plugin/shared/include/snippets/mha.hpp b/src/tests/functional/plugin/shared/include/snippets/mha.hpp index 34cb4d452bfb15..e024357fe0e2c3 100644 --- a/src/tests/functional/plugin/shared/include/snippets/mha.hpp +++ b/src/tests/functional/plugin/shared/include/snippets/mha.hpp @@ -42,9 +42,9 @@ class MHABase : virtual public SnippetsTestsCommon { void SetUp() override; void compile_model() override; void generate_inputs(const std::vector& targetInputStaticShapes) override; + void init_thresholds() override; virtual std::shared_ptr get_subgraph() const = 0; virtual void init_params(std::vector& input_shapes, ov::element::Type& prc, ov::AnyMap& additional_config) = 0; - virtual void init_thresholds(); size_t m_thread_count; std::vector m_input_types; From a0b809b36d8bd1ef426a0d19f67803514f01ec50 Mon Sep 17 00:00:00 2001 From: Jedrzej Hajduczenia Date: Thu, 13 Feb 2025 13:46:09 +0100 Subject: [PATCH 39/45] Simplify pre-MVN shape and MVN reduction axes checks to avoid switch statements in GroupNormalizationFusion pass --- .../group_normalization_fusion.cpp | 69 +++---------------- 1 file changed, 11 insertions(+), 58 deletions(-) diff --git a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp index 92a2482e32bcab..5a8b3396f63177 100644 --- a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp @@ -20,33 +20,28 @@ using namespace ov::pass::pattern; -template ::value, bool> = true> -bool pre_mvn_shape_vals_correct(const std::shared_ptr& pre_mvn_shape_const, +bool pre_mvn_shape_vals_correct(const std::vector& pre_mvn_shape_vals, const ov::PartialShape& input_ps, const ov::Dimension::value_type num_groups) { bool res = true; - std::vector pre_mvn_shape_vals = pre_mvn_shape_const->get_vector(); if (input_ps[0].is_dynamic()) { - if (static_cast(pre_mvn_shape_vals[0]) != 0ll) + if (pre_mvn_shape_vals[0] != 0ll) res = false; } else { - if ((static_cast(pre_mvn_shape_vals[0]) != 0ll) && - (static_cast(pre_mvn_shape_vals[0]) != static_cast(input_ps[0].get_max_length()))) + if ((pre_mvn_shape_vals[0] != 0ll) && + (pre_mvn_shape_vals[0] != static_cast(input_ps[0].get_max_length()))) res = false; } - if ((static_cast(pre_mvn_shape_vals[1]) != 0ll) && - (static_cast(pre_mvn_shape_vals[1]) != static_cast(num_groups))) + if ((pre_mvn_shape_vals[1] != 0ll) && (pre_mvn_shape_vals[1] != static_cast(num_groups))) res = false; - if 
(static_cast(pre_mvn_shape_vals[2]) != -1ll) + if (pre_mvn_shape_vals[2] != -1ll) res = false; return res; } -template ::value, bool> = true> -bool mvn_reduction_axes_correct(const std::shared_ptr& mvn_reduction_axes_const) { +bool mvn_reduction_axes_correct(const std::vector& mvn_reduction_axes) { bool res = true; - std::vector mvn_reduce_axes = mvn_reduction_axes_const->get_vector(); - if ((static_cast(mvn_reduce_axes[0]) != 2ll) && (static_cast(mvn_reduce_axes[0]) != -1ll)) + if ((mvn_reduction_axes[0] != 2ll) && (mvn_reduction_axes[0] != -1ll)) return false; return res; } @@ -114,42 +109,8 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { const auto& pre_mvn_shape_out_ps = pre_mvn_shape.get_shape(); if (pre_mvn_shape_out_ps[0] != 3) return false; - switch (pre_mvn_shape_const->get_element_type()) { - case ov::element::i8: - if (!pre_mvn_shape_vals_correct(pre_mvn_shape_const, input_ps, num_groups)) - return false; - break; - case ov::element::i16: - if (!pre_mvn_shape_vals_correct(pre_mvn_shape_const, input_ps, num_groups)) - return false; - break; - case ov::element::i32: - if (!pre_mvn_shape_vals_correct(pre_mvn_shape_const, input_ps, num_groups)) - return false; - break; - case ov::element::i64: - if (!pre_mvn_shape_vals_correct(pre_mvn_shape_const, input_ps, num_groups)) - return false; - break; - case ov::element::u8: - if (!pre_mvn_shape_vals_correct(pre_mvn_shape_const, input_ps, num_groups)) - return false; - break; - case ov::element::u16: - if (!pre_mvn_shape_vals_correct(pre_mvn_shape_const, input_ps, num_groups)) - return false; - break; - case ov::element::u32: - if (!pre_mvn_shape_vals_correct(pre_mvn_shape_const, input_ps, num_groups)) - return false; - break; - case ov::element::u64: - if (!pre_mvn_shape_vals_correct(pre_mvn_shape_const, input_ps, num_groups)) - return false; - break; - default: + if (!pre_mvn_shape_vals_correct(pre_mvn_shape_const->cast_vector(), input_ps, num_groups)) return false; - } // number of channels has to be divisible by number of groups if (num_channels % num_groups != 0) @@ -168,16 +129,8 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { const auto& mvn_reduction_axes_out_shape = mvn_reduction_axes.get_shape(); if (mvn_reduction_axes_out_shape[0] != 1) return false; - switch (mvn_reduction_axes_const->get_element_type()) { - case ov::element::i32: - mvn_reduction_axes_correct(mvn_reduction_axes_const); - break; - case ov::element::i64: - mvn_reduction_axes_correct(mvn_reduction_axes_const); - break; - default: - break; - } + if (!mvn_reduction_axes_correct(mvn_reduction_axes_const->cast_vector())) + return false; const auto& post_instance_norm_reshape_out_ps = pattern_map.at(post_instance_norm_reshape_m).get_partial_shape(); From c5e080f5bf01e8afa4b36d88de197eeae11d5a62 Mon Sep 17 00:00:00 2001 From: Jedrzej Hajduczenia Date: Thu, 13 Feb 2025 13:49:50 +0100 Subject: [PATCH 40/45] Remove comments in self-explanatory parts of code --- .../group_normalization_fusion.cpp | 23 ------------------- 1 file changed, 23 deletions(-) diff --git a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp index 5a8b3396f63177..23ac9d34631bca 100644 --- a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp @@ -139,25 +139,14 @@ 
ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
             return false;
 
         const auto& group_norm_gamma = pattern_map.at(group_norm_gamma_m);
-        // group_norm_gamma has to share the same data type as
-        // pattern input
         if (group_norm_gamma.get_element_type() != T)
             return false;
-
-        // number of elements in group_norm_gamma must be equal to
-        // number of channels
         if (ov::shape_size(group_norm_gamma.get_shape()) != num_channels)
             return false;
 
         const auto& group_norm_beta = pattern_map.at(group_norm_beta_m);
-
-        // group_norm_beta has to share the same data type as
-        // pattern input
         if (group_norm_beta.get_element_type() != T)
             return false;
-
-        // number of elements in group_norm_beta must be equal to
-        // number of channels
         if (ov::shape_size(group_norm_beta.get_shape()) != num_channels)
             return false;
 
@@ -185,14 +174,8 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
         std::shared_ptr instance_norm_beta_1d_m = nullptr;
         if (pattern_map.count(instance_norm_beta_m) > 0) {
             const auto& instance_norm_beta = pattern_map.at(instance_norm_beta_m);
-
-            // instance_norm_beta has to share the same data type as
-            // pattern input
             if (instance_norm_beta.get_element_type() != T)
                 return false;
-
-            // number of elements in instance_norm_beta must be equal to
-            // number of groups
             if (ov::shape_size(instance_norm_beta.get_shape()) != num_groups)
                 return false;
 
@@ -223,14 +206,8 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
 
         if (pattern_map.count(instance_norm_gamma_m) > 0) {
             const auto& instance_norm_gamma = pattern_map.at(instance_norm_gamma_m);
-
-            // instance_norm_gamma has to share the same data type as
-            // pattern input
             if (instance_norm_gamma.get_element_type() != T)
                 return false;
-
-            // number of elements in instance_norm_gamma must be equal to
-            // number of groups
             if (ov::shape_size(instance_norm_gamma.get_shape()) != num_groups)
                 return false;

From a7fef61eff26b8e030835dfddb7e71479a61b6ca Mon Sep 17 00:00:00 2001
From: Jedrzej Hajduczenia
Date: Thu, 13 Feb 2025 13:53:16 +0100
Subject: [PATCH 41/45] Remove unnecessary cast in GroupNormalizationFusion
 pass

---
 .../common_optimizations/group_normalization_fusion.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp
index 23ac9d34631bca..dd9b2aefa01728 100644
--- a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp
+++ b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp
@@ -168,8 +168,7 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() {
         auto gather_indices_vals = std::vector();
         for (auto i = 0ull; i < num_groups; i++)
             gather_indices_vals.insert(gather_indices_vals.end(), channels_to_groups_ratio, i);
-        auto gather_indices_const_m =
-            op::v0::Constant::create(element::i64, Shape{static_cast(num_channels)}, gather_indices_vals);
+        auto gather_indices_const_m = op::v0::Constant::create(element::i64, Shape{num_channels}, gather_indices_vals);
 
         std::shared_ptr instance_norm_beta_1d_m = nullptr;
         if (pattern_map.count(instance_norm_beta_m) > 0) {

From ba5438d296c1e1f055bdd6a48426c3981ecc8df1 Mon Sep 17 00:00:00 2001
From: Jedrzej Hajduczenia
Date: Thu, 13 Feb 2025 15:55:30 +0100
Subject: [PATCH 42/45] Use lambdas for pre-MVN shape and MVN reduction axes
 checks in GroupNormalizationFusion pass

---
.../group_normalization_fusion.cpp | 54 ++++++++++--------- 1 file changed, 28 insertions(+), 26 deletions(-) diff --git a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp index dd9b2aefa01728..88610d45dcdd70 100644 --- a/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp @@ -20,32 +20,6 @@ using namespace ov::pass::pattern; -bool pre_mvn_shape_vals_correct(const std::vector& pre_mvn_shape_vals, - const ov::PartialShape& input_ps, - const ov::Dimension::value_type num_groups) { - bool res = true; - if (input_ps[0].is_dynamic()) { - if (pre_mvn_shape_vals[0] != 0ll) - res = false; - } else { - if ((pre_mvn_shape_vals[0] != 0ll) && - (pre_mvn_shape_vals[0] != static_cast(input_ps[0].get_max_length()))) - res = false; - } - if ((pre_mvn_shape_vals[1] != 0ll) && (pre_mvn_shape_vals[1] != static_cast(num_groups))) - res = false; - if (pre_mvn_shape_vals[2] != -1ll) - res = false; - return res; -} - -bool mvn_reduction_axes_correct(const std::vector& mvn_reduction_axes) { - bool res = true; - if ((mvn_reduction_axes[0] != 2ll) && (mvn_reduction_axes[0] != -1ll)) - return false; - return res; -} - ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { MATCHER_SCOPE(GroupNormalizationFusion); @@ -109,6 +83,26 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { const auto& pre_mvn_shape_out_ps = pre_mvn_shape.get_shape(); if (pre_mvn_shape_out_ps[0] != 3) return false; + + auto pre_mvn_shape_vals_correct = [](const std::vector& pre_mvn_shape_vals, + const ov::PartialShape& input_ps, + const ov::Dimension::value_type num_groups) -> bool { + bool res = true; + if (input_ps[0].is_dynamic()) { + if (pre_mvn_shape_vals[0] != 0ll) + res = false; + } else { + if ((pre_mvn_shape_vals[0] != 0ll) && + (pre_mvn_shape_vals[0] != static_cast(input_ps[0].get_max_length()))) + res = false; + } + if ((pre_mvn_shape_vals[1] != 0ll) && (pre_mvn_shape_vals[1] != static_cast(num_groups))) + res = false; + if (pre_mvn_shape_vals[2] != -1ll) + res = false; + return res; + }; + if (!pre_mvn_shape_vals_correct(pre_mvn_shape_const->cast_vector(), input_ps, num_groups)) return false; @@ -129,6 +123,14 @@ ov::pass::GroupNormalizationFusion::GroupNormalizationFusion() { const auto& mvn_reduction_axes_out_shape = mvn_reduction_axes.get_shape(); if (mvn_reduction_axes_out_shape[0] != 1) return false; + + auto mvn_reduction_axes_correct = [](const std::vector& mvn_reduction_axes) -> bool { + bool res = true; + if ((mvn_reduction_axes[0] != 2ll) && (mvn_reduction_axes[0] != -1ll)) + return false; + return res; + }; + if (!mvn_reduction_axes_correct(mvn_reduction_axes_const->cast_vector())) return false; From bf73af4200c4ba907a26d16df227c0a67946e1c0 Mon Sep 17 00:00:00 2001 From: Jedrzej Hajduczenia Date: Thu, 13 Feb 2025 16:10:46 +0100 Subject: [PATCH 43/45] Remove comment describing what GroupNormalizationFusion does from GPU plugin transformation pipeline --- .../intel_gpu/src/plugin/transformations_pipeline.cpp | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index 510bd91cb71657..b513ec62ccfd40 100644 --- 
a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -408,11 +408,7 @@ void TransformationsPipeline::apply(std::shared_ptr func) { // fuse softmax, MVN patterns, so that they will not be marked as precision sensitive in ConvertPrecision manager.register_pass(); manager.register_pass(); - // fuse following ops into GroupNormalization: - // group_norm_gamma * (instance_norm_gamma * MVN(x) + instance_norm_beta) + group_norm_beta - // note that instance norm related parameters are optional: - // - instance_norm_gamma is assumed to be filled with ones if not present in the graph - // - instance_norm_beta is assumed to be filled with zeros if not present in the graph + // GroupNormalizationFusion can potentially benefit from MVNFusion manager.register_pass(); // decompose MVNs that sre not supported in GPU, so that they will be marked as precision sensitive in ConvertPrecision manager.register_pass(); From 14866b287c76df236ac0829dcbf2d33130cfb0b2 Mon Sep 17 00:00:00 2001 From: Jedrzej Hajduczenia Date: Fri, 14 Feb 2025 00:20:36 +0100 Subject: [PATCH 44/45] Remove redundant ov:: namespace prefixes in GroupNormalizationFusion test classes --- .../group_normalization_fusion_tests.cpp | 162 +++++++++--------- .../subgraph/group_normalization_fusion.hpp | 117 +++++++------ 2 files changed, 139 insertions(+), 140 deletions(-) diff --git a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp index 101ee04b758430..fa760fa7f39a64 100644 --- a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp +++ b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp @@ -22,7 +22,7 @@ using GroupNormalizationFusionSubgraphTestValues = template class GroupNormalizationFusionTransformationTestsF - : public ov::test::GroupNormalizationFusionTestBase, + : public test::GroupNormalizationFusionTestBase, public testing::TestWithParam { public: static constexpr element::Type T_elem = T_elem_t; @@ -41,11 +41,11 @@ class GroupNormalizationFusionTransformationTestsF std::ostringstream results; results << "T=" << T_elem_t << "_"; - results << "Input=" << ov::test::utils::partialShape2str({data_shape}) << "_"; - results << "InstNormGamma=" << ov::test::utils::partialShape2str({instance_norm_gamma_shape}) << "_"; - results << "InstNormBeta=" << ov::test::utils::partialShape2str({instance_norm_beta_shape}) << "_"; - results << "GroupNormGamma=" << ov::test::utils::partialShape2str({group_norm_gamma_shape}) << "_"; - results << "GroupNormBeta=" << ov::test::utils::partialShape2str({group_norm_beta_shape}) << "_"; + results << "Input=" << test::utils::partialShape2str({data_shape}) << "_"; + results << "InstNormGamma=" << test::utils::partialShape2str({instance_norm_gamma_shape}) << "_"; + results << "InstNormBeta=" << test::utils::partialShape2str({instance_norm_beta_shape}) << "_"; + results << "GroupNormGamma=" << test::utils::partialShape2str({group_norm_gamma_shape}) << "_"; + results << "GroupNormBeta=" << test::utils::partialShape2str({group_norm_beta_shape}) << "_"; results << "NumGroups=" << num_groups << "_"; results << "Epsilon=" << epsilon << "_"; results << "PositiveTest=" << std::boolalpha << positive_test << "_"; @@ -58,16 +58,16 @@ class GroupNormalizationFusionTransformationTestsF this->generate_weights_init_values(); model = 
this->create_model(); - manager = ov::pass::Manager(); - manager.register_pass(); - manager.register_pass(); + manager = pass::Manager(); + manager.register_pass(); + manager.register_pass(); OV_ASSERT_NO_THROW(manager.run_passes(model)); if (positiveTest) { model_ref = create_ref_model(); - manager_ref = ov::pass::Manager(); - manager_ref.register_pass(); + manager_ref = pass::Manager(); + manager_ref.register_pass(); OV_ASSERT_NO_THROW(manager_ref.run_passes(model_ref)); const auto& f_parameters = model->get_parameters(); @@ -96,8 +96,8 @@ class GroupNormalizationFusionTransformationTestsF const auto& gn_node = f_results[0]->get_input_node_shared_ptr(0); const auto& gn_ref_node = f_ref_results[0]->get_input_node_shared_ptr(0); - ASSERT_TRUE(ov::is_type(gn_node)); - ASSERT_TRUE(ov::is_type(gn_ref_node)); + ASSERT_TRUE(is_type(gn_node)); + ASSERT_TRUE(is_type(gn_ref_node)); ASSERT_EQ(gn_node->inputs().size(), gn_ref_node->inputs().size()); ASSERT_EQ(gn_node->inputs().size(), 3); ASSERT_EQ(gn_node->get_input_partial_shape(0), gn_ref_node->get_input_partial_shape(0)); @@ -107,23 +107,23 @@ class GroupNormalizationFusionTransformationTestsF ASSERT_EQ(shape_size(gn_node->get_input_shape(2)), shape_size(gn_ref_node->get_input_shape(2))); ASSERT_EQ(shape_size(gn_node->get_input_shape(2)), this->numChannels); - const auto& gn_node_casted = ov::as_type_ptr(gn_node); - const auto& gn_ref_node_casted = ov::as_type_ptr(gn_ref_node); + const auto& gn_node_casted = as_type_ptr(gn_node); + const auto& gn_ref_node_casted = as_type_ptr(gn_ref_node); ASSERT_EQ(gn_node_casted->get_epsilon(), gn_ref_node_casted->get_epsilon()); ASSERT_EQ(gn_node_casted->get_epsilon(), this->epsilon); ASSERT_EQ(gn_node_casted->get_num_groups(), gn_ref_node_casted->get_num_groups()); ASSERT_EQ(gn_node_casted->get_num_groups(), this->numGroups); } else { - ASSERT_EQ(count_ops_of_type(model), 0); + ASSERT_EQ(count_ops_of_type(model), 0); } } protected: bool positiveTest; - ov::pass::Manager manager; - ov::pass::Manager manager_ref; - std::shared_ptr model; - std::shared_ptr model_ref; + pass::Manager manager; + pass::Manager manager_ref; + std::shared_ptr model; + std::shared_ptr model_ref; void read_test_parameters() override { const auto& params = GetParam(); @@ -169,8 +169,8 @@ class GroupNormalizationFusionTransformationTestsF } } - std::shared_ptr create_ref_model() { - auto input = std::make_shared(T_elem, this->dataShape); + std::shared_ptr create_ref_model() { + auto input = std::make_shared(T_elem, this->dataShape); auto group_norm_beta_corr_vals = this->groupNormBetaVals; if (this->instanceNormBetaPresent) @@ -189,11 +189,11 @@ class GroupNormalizationFusionTransformationTestsF auto group_norm_gamma_1d = op::v0::Constant::create(T_elem, Shape{this->numChannels}, group_norm_gamma_corr_vals); - auto group_norm = std::make_shared(input, - group_norm_gamma_1d, - group_norm_beta_1d, - this->numGroups, - this->epsilon); + auto group_norm = std::make_shared(input, + group_norm_gamma_1d, + group_norm_beta_1d, + this->numGroups, + this->epsilon); return std::make_shared(NodeVector{group_norm}, ParameterVector{input}); } @@ -293,7 +293,7 @@ TEST_P(GroupNormalizationFusionTransformationTestsF_f8e8m0, GroupNormalizationFu using GroupNormalizationFusionSubgraphTestAdditionalValues = std::tuple; // whether it's a positive test that should run reference model or a negative test -std::vector valid_vals = { +std::vector valid_vals = { std::make_tuple(PartialShape{1, 320}, Shape{}, Shape{}, Shape{320}, Shape{320}, 1, 1e-5f), 
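+    // each tuple: data shape, instance norm gamma/beta shapes, group norm gamma/beta shapes,
+    // num_groups, epsilon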
std::make_tuple(PartialShape{1, 320, 2, 2}, Shape{1, 1, 1}, @@ -348,7 +348,7 @@ std::vector valid_vals = { 64, 1e-6f)}; -std::vector invalid_vals = { +std::vector invalid_vals = { std::make_tuple(PartialShape{1, 320}, Shape{}, Shape{}, Shape{}, Shape{}, 1, 1e-5f), std::make_tuple(PartialShape{1, 320, 2, 2}, Shape{1, 1, 1}, @@ -383,156 +383,156 @@ std::vector invalid_vals = { INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationPositiveTests_f32, GroupNormalizationFusionTransformationTestsF_f32, - ValuesIn(ov::test::expand_vals(valid_vals, - GroupNormalizationFusionSubgraphTestAdditionalValues(true))), + ValuesIn(test::expand_vals(valid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(true))), GroupNormalizationFusionTransformationTestsF_f32::getTestCaseName); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationPositiveTests_f16, GroupNormalizationFusionTransformationTestsF_f16, - ValuesIn(ov::test::expand_vals(valid_vals, - GroupNormalizationFusionSubgraphTestAdditionalValues(true))), + ValuesIn(test::expand_vals(valid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(true))), GroupNormalizationFusionTransformationTestsF_f16::getTestCaseName); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationPositiveTests_bf16, GroupNormalizationFusionTransformationTestsF_bf16, - ValuesIn(ov::test::expand_vals(valid_vals, - GroupNormalizationFusionSubgraphTestAdditionalValues(true))), + ValuesIn(test::expand_vals(valid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(true))), GroupNormalizationFusionTransformationTestsF_bf16::getTestCaseName); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTests_f32, GroupNormalizationFusionTransformationTestsF_f32, - ValuesIn(ov::test::expand_vals(invalid_vals, - GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + ValuesIn(test::expand_vals(invalid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), GroupNormalizationFusionTransformationTestsF_f32::getTestCaseName); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTests_f16, GroupNormalizationFusionTransformationTestsF_f16, - ValuesIn(ov::test::expand_vals(invalid_vals, - GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + ValuesIn(test::expand_vals(invalid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), GroupNormalizationFusionTransformationTestsF_f16::getTestCaseName); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTests_bf16, GroupNormalizationFusionTransformationTestsF_bf16, - ValuesIn(ov::test::expand_vals(invalid_vals, - GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + ValuesIn(test::expand_vals(invalid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), GroupNormalizationFusionTransformationTestsF_bf16::getTestCaseName); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsValidVals_u8, GroupNormalizationFusionTransformationTestsF_u8, - ValuesIn(ov::test::expand_vals(valid_vals, - GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + ValuesIn(test::expand_vals(valid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), GroupNormalizationFusionTransformationTestsF_u8::getTestCaseName); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsValidVals_u16, GroupNormalizationFusionTransformationTestsF_u16, - ValuesIn(ov::test::expand_vals(valid_vals, - 
GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + ValuesIn(test::expand_vals(valid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), GroupNormalizationFusionTransformationTestsF_u16::getTestCaseName); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsValidVals_u32, GroupNormalizationFusionTransformationTestsF_u32, - ValuesIn(ov::test::expand_vals(valid_vals, - GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + ValuesIn(test::expand_vals(valid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), GroupNormalizationFusionTransformationTestsF_u32::getTestCaseName); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsValidVals_u64, GroupNormalizationFusionTransformationTestsF_u64, - ValuesIn(ov::test::expand_vals(valid_vals, - GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + ValuesIn(test::expand_vals(valid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), GroupNormalizationFusionTransformationTestsF_u64::getTestCaseName); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsValidVals_i8, GroupNormalizationFusionTransformationTestsF_i8, - ValuesIn(ov::test::expand_vals(valid_vals, - GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + ValuesIn(test::expand_vals(valid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), GroupNormalizationFusionTransformationTestsF_i8::getTestCaseName); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsValidVals_i16, GroupNormalizationFusionTransformationTestsF_i16, - ValuesIn(ov::test::expand_vals(valid_vals, - GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + ValuesIn(test::expand_vals(valid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), GroupNormalizationFusionTransformationTestsF_i16::getTestCaseName); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsValidVals_i32, GroupNormalizationFusionTransformationTestsF_i32, - ValuesIn(ov::test::expand_vals(valid_vals, - GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + ValuesIn(test::expand_vals(valid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), GroupNormalizationFusionTransformationTestsF_i32::getTestCaseName); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsValidVals_f8e5m2, GroupNormalizationFusionTransformationTestsF_f8e5m2, - ValuesIn(ov::test::expand_vals(valid_vals, - GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + ValuesIn(test::expand_vals(valid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), GroupNormalizationFusionTransformationTestsF_f8e5m2::getTestCaseName); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsValidVals_f4e2m1, GroupNormalizationFusionTransformationTestsF_f4e2m1, - ValuesIn(ov::test::expand_vals(valid_vals, - GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + ValuesIn(test::expand_vals(valid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), GroupNormalizationFusionTransformationTestsF_f4e2m1::getTestCaseName); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsValidVals_f8e8m0, GroupNormalizationFusionTransformationTestsF_f8e8m0, - ValuesIn(ov::test::expand_vals(valid_vals, - GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + ValuesIn(test::expand_vals(valid_vals, + 
GroupNormalizationFusionSubgraphTestAdditionalValues(false))), GroupNormalizationFusionTransformationTestsF_f8e8m0::getTestCaseName); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsInvalidVals_u8, GroupNormalizationFusionTransformationTestsF_u8, - ValuesIn(ov::test::expand_vals(invalid_vals, - GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + ValuesIn(test::expand_vals(invalid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), GroupNormalizationFusionTransformationTestsF_u8::getTestCaseName); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsInvalidVals_u16, GroupNormalizationFusionTransformationTestsF_u16, - ValuesIn(ov::test::expand_vals(invalid_vals, - GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + ValuesIn(test::expand_vals(invalid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), GroupNormalizationFusionTransformationTestsF_u16::getTestCaseName); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsInvalidVals_u32, GroupNormalizationFusionTransformationTestsF_u32, - ValuesIn(ov::test::expand_vals(invalid_vals, - GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + ValuesIn(test::expand_vals(invalid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), GroupNormalizationFusionTransformationTestsF_u32::getTestCaseName); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsInvalidVals_u64, GroupNormalizationFusionTransformationTestsF_u64, - ValuesIn(ov::test::expand_vals(invalid_vals, - GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + ValuesIn(test::expand_vals(invalid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), GroupNormalizationFusionTransformationTestsF_u64::getTestCaseName); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsInvalidVals_i8, GroupNormalizationFusionTransformationTestsF_i8, - ValuesIn(ov::test::expand_vals(invalid_vals, - GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + ValuesIn(test::expand_vals(invalid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), GroupNormalizationFusionTransformationTestsF_i8::getTestCaseName); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsInvalidVals_i16, GroupNormalizationFusionTransformationTestsF_i16, - ValuesIn(ov::test::expand_vals(invalid_vals, - GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + ValuesIn(test::expand_vals(invalid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), GroupNormalizationFusionTransformationTestsF_i16::getTestCaseName); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsInvalidVals_i32, GroupNormalizationFusionTransformationTestsF_i32, - ValuesIn(ov::test::expand_vals(invalid_vals, - GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + ValuesIn(test::expand_vals(invalid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), GroupNormalizationFusionTransformationTestsF_i32::getTestCaseName); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsInvalidVals_f8e5m2, GroupNormalizationFusionTransformationTestsF_f8e5m2, - ValuesIn(ov::test::expand_vals(invalid_vals, - GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + ValuesIn(test::expand_vals(invalid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))),
GroupNormalizationFusionTransformationTestsF_f8e5m2::getTestCaseName); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsInvalidVals_f4e2m1, GroupNormalizationFusionTransformationTestsF_f4e2m1, - ValuesIn(ov::test::expand_vals(invalid_vals, - GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + ValuesIn(test::expand_vals(invalid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), GroupNormalizationFusionTransformationTestsF_f4e2m1::getTestCaseName); INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionTransformationNegativeTestsInvalidVals_f8e8m0, GroupNormalizationFusionTransformationTestsF_f8e8m0, - ValuesIn(ov::test::expand_vals(invalid_vals, - GroupNormalizationFusionSubgraphTestAdditionalValues(false))), + ValuesIn(test::expand_vals(invalid_vals, + GroupNormalizationFusionSubgraphTestAdditionalValues(false))), GroupNormalizationFusionTransformationTestsF_f8e8m0::getTestCaseName); \ No newline at end of file diff --git a/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp b/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp index c89e8e366ddb6a..bf017b129a388e 100644 --- a/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp +++ b/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp @@ -17,11 +17,11 @@ namespace ov { namespace test { using GroupNormalizationFusionTestBaseValues = - std::tuple; // epsilon @@ -35,9 +35,9 @@ using GroupNormalizationFusionTransformationsTestValues = float, // epsilon bool, // whether it's a positive test that should run reference model or a negative test std::string, // target device name - ov::AnyMap, // target device properties + AnyMap, // target device properties std::string, // reference device name - ov::AnyMap>; // reference device properties + AnyMap>; // reference device properties template std::vector> expand_vals(std::vector> old_vals, @@ -54,7 +54,7 @@ template class GroupNormalizationFusionTestBase { public: static constexpr element::Type T_elem = T_elem_t; - typedef typename ov::element_type_traits::value_type T_store_t; + typedef typename element_type_traits::value_type T_store_t; protected: size_t numChannels; @@ -85,11 +85,11 @@ class GroupNormalizationFusionTestBase { groupNormBetaVals = test::utils::generateVector(shape_size(groupNormBetaShape), 10, 1, 4); } - std::shared_ptr create_model() { + std::shared_ptr create_model() { auto input = std::make_shared(T_elem, dataShape); auto pre_mvn_shape_const = op::v0::Constant::create(element::i64, Shape{3}, {0, static_cast(numGroups), -1}); - auto pre_mvn_reshape = std::make_shared(input, pre_mvn_shape_const, true); + auto pre_mvn_reshape = std::make_shared(input, pre_mvn_shape_const, true); auto mvn_axes_const = op::v0::Constant::create(element::i64, Shape{1}, {2}); auto mvn = std::make_shared(pre_mvn_reshape, mvn_axes_const, true, epsilon, op::MVNEpsMode::INSIDE_SQRT); auto opt_instance_norm_gamma_multiply = std::dynamic_pointer_cast(mvn); if (instanceNormGammaPresent) { auto instance_norm_gamma_const = op::v0::Constant::create(T_elem, instanceNormGammaShape, instanceNormGammaVals); opt_instance_norm_gamma_multiply = std::make_shared(mvn, instance_norm_gamma_const); } - std::shared_ptr opt_instance_norm_beta_add = opt_instance_norm_gamma_multiply; + std::shared_ptr opt_instance_norm_beta_add = opt_instance_norm_gamma_multiply; if (instanceNormBetaPresent) { auto instance_norm_beta_const = op::v0::Constant::create(T_elem, instanceNormBetaShape, instanceNormBetaVals); opt_instance_norm_beta_add = - std::make_shared(opt_instance_norm_gamma_multiply,
instance_norm_beta_const); + std::make_shared(opt_instance_norm_gamma_multiply, instance_norm_beta_const); } - auto post_instance_norm_shape = std::make_shared(input); + auto post_instance_norm_shape = std::make_shared(input); auto post_instance_norm_reshape = std::make_shared(opt_instance_norm_beta_add, post_instance_norm_shape, true); @@ -129,7 +129,7 @@ class GroupNormalizationFusionTestBase { template class GroupNormalizationFusionSubgraphTestsF : public GroupNormalizationFusionTestBase, - public ov::test::SubgraphBaseTest, + public test::SubgraphBaseTest, public testing::WithParamInterface { public: static constexpr element::Type T_elem = T_elem_t; @@ -153,11 +153,11 @@ class GroupNormalizationFusionSubgraphTestsF std::ostringstream results; results << "T=" << T_elem_t << "_"; - results << "Input=" << ov::test::utils::partialShape2str({data_shape}) << "_"; - results << "InstNormGamma=" << ov::test::utils::partialShape2str({instance_norm_gamma_shape}) << "_"; - results << "InstNormBeta=" << ov::test::utils::partialShape2str({instance_norm_beta_shape}) << "_"; - results << "GroupNormGamma=" << ov::test::utils::partialShape2str({group_norm_gamma_shape}) << "_"; - results << "GroupNormBeta=" << ov::test::utils::partialShape2str({group_norm_beta_shape}) << "_"; + results << "Input=" << test::utils::partialShape2str({data_shape}) << "_"; + results << "InstNormGamma=" << test::utils::partialShape2str({instance_norm_gamma_shape}) << "_"; + results << "InstNormBeta=" << test::utils::partialShape2str({instance_norm_beta_shape}) << "_"; + results << "GroupNormGamma=" << test::utils::partialShape2str({group_norm_gamma_shape}) << "_"; + results << "GroupNormBeta=" << test::utils::partialShape2str({group_norm_beta_shape}) << "_"; results << "NumGroups=" << num_groups << "_"; results << "Epsilon=" << epsilon << "_"; results << "PositiveTest=" << std::boolalpha << positive_test << "_"; @@ -183,13 +183,13 @@ class GroupNormalizationFusionSubgraphTestsF protected: bool positiveTest; std::string targetDeviceName; - ov::AnyMap targetConfiguration; + AnyMap targetConfiguration; std::string refDevice; - ov::AnyMap refConfiguration; + AnyMap refConfiguration; ElementType refInferencePrecision; - ov::CompiledModel compiledRefModel; - ov::InferRequest refInferRequest; + CompiledModel compiledRefModel; + InferRequest refInferRequest; void TearDown() override { SubgraphBaseTest::TearDown(); @@ -249,24 +249,24 @@ class GroupNormalizationFusionSubgraphTestsF } void configure_device() { - if (targetConfiguration.count(ov::hint::inference_precision.name()) <= 0) { - targetConfiguration.insert({ov::hint::inference_precision.name(), T_elem}); + if (targetConfiguration.count(hint::inference_precision.name()) <= 0) { + targetConfiguration.insert({hint::inference_precision.name(), T_elem}); } } void configure_ref_device() { - if (refConfiguration.count(ov::hint::inference_precision.name()) <= 0) { - refConfiguration.insert({ov::hint::inference_precision.name(), T_elem}); + if (refConfiguration.count(hint::inference_precision.name()) <= 0) { + refConfiguration.insert({hint::inference_precision.name(), T_elem}); } } void configure_ref_model() { // configure input precision - ov::preprocess::PrePostProcessor p(functionRefs); + preprocess::PrePostProcessor p(functionRefs); { auto& params = functionRefs->get_parameters(); for (size_t i = 0; i < params.size(); i++) { - if (inType != ov::element::Type_t::undefined) { + if (inType != element::Type_t::undefined) { p.input(i).tensor().set_element_type(inType); } } @@ -276,7 
+276,7 @@ class GroupNormalizationFusionSubgraphTestsF { auto results = functionRefs->get_results(); for (size_t i = 0; i < results.size(); i++) { - if (outType != ov::element::Type_t::undefined) { + if (outType != element::Type_t::undefined) { p.output(i).tensor().set_element_type(outType); } } @@ -302,7 +302,7 @@ class GroupNormalizationFusionSubgraphTestsF << duration.count() << "s" << std::endl; } try { - refInferencePrecision = core->get_property(refDevice, ov::hint::inference_precision); + refInferencePrecision = core->get_property(refDevice, hint::inference_precision); } catch (std::exception& e) { std::cout << "[ WARNING ] Impossible to get Inference Precision with exception: " << e.what() << std::endl; } @@ -317,7 +317,7 @@ class GroupNormalizationFusionSubgraphTestsF } } - void infer_ref(const std::map, ov::Tensor>& inputs_ref) { + void infer_ref(const std::map, Tensor>& inputs_ref) { refInferRequest = compiledRefModel.create_infer_request(); for (const auto& input : inputs_ref) { refInferRequest.set_tensor(input.first, input.second); @@ -325,7 +325,7 @@ class GroupNormalizationFusionSubgraphTestsF refInferRequest.infer(); } - std::vector calculate_refs() override { + std::vector calculate_refs() override { if (is_report_stages) { std::cout << "[ REFERENCE ] `GroupNormalizationFusionSubgraphTestsF::calculate_refs()` is started" << std::endl; @@ -335,13 +335,13 @@ class GroupNormalizationFusionSubgraphTestsF update_ref_model(); match_parameters(function->get_parameters(), functionRefs->get_parameters()); - std::map, ov::Tensor> inputs_ref; + std::map, Tensor> inputs_ref; for (const auto& param : functionRefs->get_parameters()) { inputs_ref[param] = inputs.at(matched_parameters[param]); } infer_ref(inputs_ref); - auto outputs = std::vector{}; + auto outputs = std::vector{}; for (const auto& output : functionRefs->outputs()) { outputs.push_back(refInferRequest.get_tensor(output)); } @@ -355,18 +355,18 @@ class GroupNormalizationFusionSubgraphTestsF return outputs; } - void generate_inputs(const std::vector& targetInputStaticShapes) override { + void generate_inputs(const std::vector& targetInputStaticShapes) override { inputs.clear(); auto itTargetShape = targetInputStaticShapes.begin(); for (const auto& param : function->get_parameters()) { - std::shared_ptr inputNode = param; + std::shared_ptr inputNode = param; for (size_t i = 0; i < param->get_output_size(); i++) { for (const auto& node : param->get_output_target_inputs(i)) { - std::shared_ptr nodePtr = node.get_node()->shared_from_this(); + std::shared_ptr nodePtr = node.get_node()->shared_from_this(); for (size_t port = 0; port < nodePtr->get_input_size(); ++port) { if (nodePtr->get_input_node_ptr(port)->shared_from_this() == inputNode->shared_from_this()) { - const auto& tensor = ov::test::utils::create_and_fill_tensor(inType, *itTargetShape); + const auto& tensor = test::utils::create_and_fill_tensor(inType, *itTargetShape); inputs.insert({param, tensor}); break; } @@ -380,26 +380,25 @@ class GroupNormalizationFusionSubgraphTestsF public: void run() override { is_reported = true; - bool isCurrentTestDisabled = ov::test::utils::current_test_is_disabled(); + bool isCurrentTestDisabled = test::utils::current_test_is_disabled(); - ov::test::utils::PassRate::Statuses status = isCurrentTestDisabled - ? ov::test::utils::PassRate::Statuses::SKIPPED - : ov::test::utils::PassRate::Statuses::CRASHED; + test::utils::PassRate::Statuses status = + isCurrentTestDisabled ? 
test::utils::PassRate::Statuses::SKIPPED : test::utils::PassRate::Statuses::CRASHED; if (isCurrentTestDisabled) GTEST_SKIP() << "Disabled test due to configuration" << std::endl; // in case of crash jump will be made and work will be continued - auto crashHandler = std::unique_ptr(new ov::test::utils::CrashHandler()); + auto crashHandler = std::unique_ptr(new test::utils::CrashHandler()); // place to jump in case of a crash int jmpRes = 0; #ifdef _WIN32 - jmpRes = setjmp(ov::test::utils::env); + jmpRes = setjmp(test::utils::env); #else - jmpRes = sigsetjmp(ov::test::utils::env, 1); + jmpRes = sigsetjmp(test::utils::env, 1); #endif - if (jmpRes == ov::test::utils::JMP_STATUS::ok) { + if (jmpRes == test::utils::JMP_STATUS::ok) { crashHandler->StartTimer(); std::string errorMessage; try { @@ -408,14 +407,14 @@ class GroupNormalizationFusionSubgraphTestsF functionRefs = this->create_model(); function = functionRefs->clone(); pass::Manager m; - m.register_pass(); + m.register_pass(); OV_ASSERT_NO_THROW(m.run_passes(function)); summary.setDeviceName(targetDevice); summary.updateOPsStats(function, status, rel_influence_coef); if (positiveTest) { - ASSERT_EQ(count_ops_of_type(functionRefs), 0); - ASSERT_EQ(count_ops_of_type(function), 1); + ASSERT_EQ(count_ops_of_type(functionRefs), 0); + ASSERT_EQ(count_ops_of_type(function), 1); if (!function->is_dynamic()) { configure_device(); @@ -432,31 +431,31 @@ class GroupNormalizationFusionSubgraphTestsF } } } else { - ASSERT_EQ(count_ops_of_type(functionRefs), 0); - ASSERT_EQ(count_ops_of_type(function), 0); + ASSERT_EQ(count_ops_of_type(functionRefs), 0); + ASSERT_EQ(count_ops_of_type(function), 0); } - status = ov::test::utils::PassRate::Statuses::PASSED; + status = test::utils::PassRate::Statuses::PASSED; } catch (const std::exception& ex) { if (callback_exception != nullptr) { // exception will be checked by callback. callback_exception(ex); return; } else { - status = ov::test::utils::PassRate::Statuses::FAILED; + status = test::utils::PassRate::Statuses::FAILED; errorMessage = ex.what(); } } catch (...) 
{ - status = ov::test::utils::PassRate::Statuses::FAILED; + status = test::utils::PassRate::Statuses::FAILED; errorMessage = "Unknown failure occurred."; } summary.updateOPsStats(function, status, rel_influence_coef); - if (status != ov::test::utils::PassRate::Statuses::PASSED) { + if (status != test::utils::PassRate::Statuses::PASSED) { GTEST_FATAL_FAILURE_(errorMessage.c_str()); } - } else if (jmpRes == ov::test::utils::JMP_STATUS::anyError) { + } else if (jmpRes == test::utils::JMP_STATUS::anyError) { OPENVINO_THROW("Crash happens"); - } else if (jmpRes == ov::test::utils::JMP_STATUS::alarmErr) { - summary.updateOPsStats(function, ov::test::utils::PassRate::Statuses::HANGED, rel_influence_coef); + } else if (jmpRes == test::utils::JMP_STATUS::alarmErr) { + summary.updateOPsStats(function, test::utils::PassRate::Statuses::HANGED, rel_influence_coef); OPENVINO_THROW("Crash happens"); } } From 065317e3c06ff15f3e90d49b5d1884efdeee7303 Mon Sep 17 00:00:00 2001 From: Jedrzej Hajduczenia Date: Fri, 14 Feb 2025 00:32:14 +0100 Subject: [PATCH 45/45] Use size_t as data type for number of groups in GroupNormalizationFusion tests --- .../common_optimizations/group_normalization_fusion_tests.cpp | 2 +- .../subgraph/group_normalization_fusion.hpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp index fa760fa7f39a64..cbf43eade7013a 100644 --- a/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp +++ b/src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp @@ -16,7 +16,7 @@ using GroupNormalizationFusionSubgraphTestValues = Shape, // shape of optional instance norm beta tensor (or empty shape if not used) Shape, // shape of group norm gamma tensor Shape, // shape of group norm beta tensor - unsigned long long, // number of groups + size_t, // number of groups float, // epsilon bool>; // whether it's a positive test that should run reference model or a negative test diff --git a/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp b/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp index bf017b129a388e..280da83058ce0d 100644 --- a/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp +++ b/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/group_normalization_fusion.hpp @@ -22,7 +22,7 @@ using GroupNormalizationFusionTestBaseValues = Shape, // shape of optional instance norm beta tensor (or empty shape if not used) Shape, // shape of group norm gamma tensor Shape, // shape of group norm beta tensor - unsigned long long, // number of groups + size_t, // number of groups float>; // epsilon using GroupNormalizationFusionTransformationsTestValues = @@ -31,7 +31,7 @@ using GroupNormalizationFusionTransformationsTestValues = Shape, // shape of optional instance norm beta tensor (or empty shape if not used) Shape, // shape of group norm gamma tensor Shape, // shape of group norm beta tensor - unsigned long long, // number of groups + size_t, // number of groups float, // epsilon bool, // whether it's a positive test that should run reference model or a negative test std::string, // target device name
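
A note on PATCH 44: the prefix removal is safe because both test files already sit inside namespace ov { namespace test { ... } }, so unqualified name lookup walks outward from ov::test into ov and finds the same entities under their shorter spellings. A minimal standalone sketch of that lookup rule (the nested declarations below are stand-ins, not the real OpenVINO ones):

namespace ov {
namespace pass {
struct Manager {};  // stand-in for ov::pass::Manager
}  // namespace pass
namespace test {
namespace utils {
inline const char* partialShape2str() {  // stand-in for the real helper
    return "{1,320}";
}
}  // namespace utils

inline void lookup_demo() {
    // Inside ov::test, lookup searches ov::test first, then the enclosing ov,
    // so the ov:: prefix adds nothing:
    pass::Manager m;                        // resolves to ov::pass::Manager
    (void)m;
    (void)utils::partialShape2str();        // resolves to ov::test::utils::partialShape2str
    (void)test::utils::partialShape2str();  // 'test' itself is found in the enclosing 'ov'
}
}  // namespace test
}  // namespace ov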
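
The INSTANTIATE_TEST_SUITE_P calls in PATCH 44 all funnel their tuples through test::expand_vals, whose body does not appear in this series. Judging from the call sites (a vector of base tuples plus one tuple of trailing values, such as the positive/negative flag), it likely amounts to a std::tuple_cat over every element; the sketch below is an assumption, not the upstream implementation:

#include <tuple>
#include <vector>

// Sketch: append a fixed tuple of additional values to every tuple in a vector,
// widening the base test tuples into full parameterized-test tuples.
template <typename... OldVals, typename... AddVals>
std::vector<std::tuple<OldVals..., AddVals...>> expand_vals(
    const std::vector<std::tuple<OldVals...>>& old_vals,
    const std::tuple<AddVals...>& add_vals) {
    std::vector<std::tuple<OldVals..., AddVals...>> expanded;
    expanded.reserve(old_vals.size());
    for (const auto& vals : old_vals)
        expanded.push_back(std::tuple_cat(vals, add_vals));
    return expanded;
}

// Usage mirroring the suites above: expand_vals(valid_vals, std::make_tuple(true))
// tags every base tuple with the positive-test flag.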
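
On PATCH 45: the commit message does not spell out the motivation, but size_t is a natural fit because the group count is divided into tensor dimensions, and ov::Shape is backed by std::vector<size_t>; keeping a single unsigned type avoids narrowing in the shape arithmetic. A small illustration under that assumption (shape_size below is a local stand-in for the OpenVINO helper of the same name):

#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

using Shape = std::vector<size_t>;  // ov::Shape derives from std::vector<size_t>

// Local stand-in for ov::shape_size: product of all dimensions.
inline size_t shape_size(const Shape& shape) {
    return std::accumulate(shape.begin(), shape.end(), size_t{1}, std::multiplies<size_t>{});
}

int main() {
    Shape group_norm_gamma_shape{320};
    size_t num_groups = 32;  // same type as the dimensions it divides
    size_t num_channels = shape_size(group_norm_gamma_shape);
    // No mixed signedness or width in the divisibility check the fusion relies on:
    return num_channels % num_groups == 0 ? 0 : 1;
}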