-
Notifications
You must be signed in to change notification settings - Fork 2.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TRANSFORMATIONS][GPU] Add GroupNormalization fusion to common optimizations #28387
base: master
Are you sure you want to change the base?
[TRANSFORMATIONS][GPU] Add GroupNormalization fusion to common optimizations #28387
Conversation
3a67c81
to
ba87e35
Compare
build_jenkins |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lgtm regarding Core part.
src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp
Show resolved
Hide resolved
ba87e35
to
3bd623d
Compare
build_jenkins |
3bd623d
to
d69391b
Compare
d69391b
to
b1ac67d
Compare
...mmon/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp
Show resolved
Hide resolved
...mmon/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp
Outdated
Show resolved
Hide resolved
.../transformations/include/transformations/common_optimizations/group_normalization_fusion.hpp
Outdated
Show resolved
Hide resolved
...mmon/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp
Outdated
Show resolved
Hide resolved
...mmon/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp
Outdated
Show resolved
Hide resolved
src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp
Outdated
Show resolved
Hide resolved
src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp
Show resolved
Hide resolved
src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp
Outdated
Show resolved
Hide resolved
...mmon/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp
Outdated
Show resolved
Hide resolved
src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp
Outdated
Show resolved
Hide resolved
src/common/transformations/tests/common_optimizations/group_normalization_fusion_tests.cpp
Outdated
Show resolved
Hide resolved
.../transformations/include/transformations/common_optimizations/group_normalization_fusion.hpp
Show resolved
Hide resolved
.../transformations/include/transformations/common_optimizations/group_normalization_fusion.hpp
Outdated
Show resolved
Hide resolved
9cf1017
to
8382572
Compare
...mmon/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp
Outdated
Show resolved
Hide resolved
...mmon/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp
Outdated
Show resolved
Hide resolved
...mmon/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp
Outdated
Show resolved
Hide resolved
...mmon/transformations/src/transformations/common_optimizations/group_normalization_fusion.cpp
Show resolved
Hide resolved
…lizationFusion tests
…sformations pipeline
…es in GroupNormalizationFusion pass
…unctional subraph test fixture class
…ationFusionTestBase in derived classes' templates
…tional subgraph test
…izationFusion tests
…pNormalizationFusion pass and tests
493e486
to
e73dd5d
Compare
return false; | ||
break; | ||
case ov::element::u64: | ||
if (!pre_mvn_shape_vals_correct<uint64_t>(pre_mvn_shape_const, input_ps, num_groups)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this switch can be removed and replaced with single if:
if (!pre_mvn_shape_vals_correct(pre_mvn_shape_const->cast_vector<int64_t>(), input_ps, num_groups))
const auto& mvn_reduction_axes_out_shape = mvn_reduction_axes.get_shape(); | ||
if (mvn_reduction_axes_out_shape[0] != 1) | ||
return false; | ||
switch (mvn_reduction_axes_const->get_element_type()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here
return false; | ||
|
||
// number of elements in group_norm_beta must be equal to | ||
// number of channels |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: IMO, comment here and above are not needed as the code is more or less simple
for (auto i = 0ull; i < num_groups; i++) | ||
gather_indices_vals.insert(gather_indices_vals.end(), channels_to_groups_ratio, i); | ||
auto gather_indices_const_m = | ||
op::v0::Constant::create(element::i64, Shape{static_cast<size_t>(num_channels)}, gather_indices_vals); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
op::v0::Constant::create(element::i64, Shape{static_cast<size_t>(num_channels)}, gather_indices_vals); | |
op::v0::Constant::create(element::i64, Shape{num_channels}, gather_indices_vals); |
class GroupNormalizationFusionTransformationTestsF_f32 | ||
: public GroupNormalizationFusionTransformationTestsF<element::Type_t::f32> {}; | ||
class GroupNormalizationFusionTransformationTestsF_f16 | ||
: public GroupNormalizationFusionTransformationTestsF<element::Type_t::f16> {}; | ||
class GroupNormalizationFusionTransformationTestsF_bf16 | ||
: public GroupNormalizationFusionTransformationTestsF<element::Type_t::bf16> {}; | ||
class GroupNormalizationFusionTransformationTestsF_u8 | ||
: public GroupNormalizationFusionTransformationTestsF<element::Type_t::u8> {}; | ||
class GroupNormalizationFusionTransformationTestsF_u16 | ||
: public GroupNormalizationFusionTransformationTestsF<element::Type_t::u16> {}; | ||
class GroupNormalizationFusionTransformationTestsF_u32 | ||
: public GroupNormalizationFusionTransformationTestsF<element::Type_t::u32> {}; | ||
class GroupNormalizationFusionTransformationTestsF_u64 | ||
: public GroupNormalizationFusionTransformationTestsF<element::Type_t::u64> {}; | ||
class GroupNormalizationFusionTransformationTestsF_i8 | ||
: public GroupNormalizationFusionTransformationTestsF<element::Type_t::i8> {}; | ||
class GroupNormalizationFusionTransformationTestsF_i16 | ||
: public GroupNormalizationFusionTransformationTestsF<element::Type_t::i16> {}; | ||
class GroupNormalizationFusionTransformationTestsF_i32 | ||
: public GroupNormalizationFusionTransformationTestsF<element::Type_t::i32> {}; | ||
class GroupNormalizationFusionTransformationTestsF_i64 | ||
: public GroupNormalizationFusionTransformationTestsF<element::Type_t::i64> {}; | ||
class GroupNormalizationFusionTransformationTestsF_f8e4m3 | ||
: public GroupNormalizationFusionTransformationTestsF<element::Type_t::f8e4m3> {}; | ||
class GroupNormalizationFusionTransformationTestsF_f8e5m2 | ||
: public GroupNormalizationFusionTransformationTestsF<element::Type_t::f8e5m2> {}; | ||
class GroupNormalizationFusionTransformationTestsF_f4e2m1 | ||
: public GroupNormalizationFusionTransformationTestsF<element::Type_t::f4e2m1> {}; | ||
class GroupNormalizationFusionTransformationTestsF_f8e8m0 | ||
: public GroupNormalizationFusionTransformationTestsF<element::Type_t::f8e8m0> {}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IMO, test cases for integer types shall be removed as spec says that supported tensor types are floating-point for this op
} | ||
|
||
template <element::Type_t T_elem_t> | ||
class GroupNormalizationFusionTestBase { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think type template is not needed for this test as it can be just a parameter of the test. As I can see, in most of the cases you actually use element::Type
object as a parameter. The only exception is const data generation, but you can use ov::test::utils::create_and_fill_tensor
method to instead, thus the template arg won't be needed here too. That change will simplify tests instantiations I believe
GroupNormalizationFusionSubgraphTestsF_f4e2m1::getTestCaseName); | ||
|
||
INSTANTIATE_TEST_SUITE_P(GroupNormalizationFusionSubgraphNegativeTestsValidVals_f8e8m0, | ||
GroupNormalizationFusionSubgraphTestsF_f8e8m0, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder why these test cases work while fp8/fp4/bf16 types are not supported by GPU plugin at all and integer types are not supported by the kernel?
ov::test::utils::DEVICE_TEMPLATE, | ||
{{"DISABLE_TRANSFORMATIONS", true}}))), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need these args as tests parametes? I'd expect them to be hardcoded in the test
refInferRequest.infer(); | ||
} | ||
|
||
std::vector<ov::Tensor> calculate_refs() override { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the difference between this and base class version? The base one also runs model via template plugin
} | ||
|
||
public: | ||
void run() override { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see some small difference for this impl vs base method, but IMO those differences should be handled separately to avoid copy-pasting of key execution logic. For instance, count of group norm ops can be checked in a separate method like this:
TEST_P(GroupNormalizationFusionSubgraphTestsF, CompareWithRefs) {
run();
check_some_plugin_or_test_specific_things_after_base_values_checks(); // add this method to GroupNormalizationFusionSubgraphTestsF
};
Details:
Tickets: