From 637d510c436efc9bf771a7ec480b74ba8bef5dfc Mon Sep 17 00:00:00 2001 From: MATEUSZ MIKOLAJCZYK Date: Wed, 18 Dec 2024 11:06:23 +0000 Subject: [PATCH 01/47] Einsum core improvements Signed-off-by: MATEUSZ MIKOLAJCZYK --- src/core/reference/src/op/einsum.cpp | 25 ++- .../include/einsum_shape_inference.hpp | 21 +- src/core/src/op/einsum.cpp | 5 - src/core/tests/type_prop/einsum.cpp | 208 +++++++++++++++++- 4 files changed, 233 insertions(+), 26 deletions(-) diff --git a/src/core/reference/src/op/einsum.cpp b/src/core/reference/src/op/einsum.cpp index b8b23964346225..3b26491caa7b7b 100644 --- a/src/core/reference/src/op/einsum.cpp +++ b/src/core/reference/src/op/einsum.cpp @@ -425,7 +425,7 @@ void broadcast_input(ov::TensorVector& inputs, OPENVINO_ASSERT(input_ind < inputs.size()); ov::Tensor& input = inputs[input_ind]; const Shape old_shape = input.get_shape(); - Shape new_shape; + PartialShape new_shape; new_shape.insert(new_shape.end(), new_common_shape.begin(), new_common_shape.end()); if (is_separate_first) { new_shape.insert(new_shape.end(), separate_shape.begin(), separate_shape.end()); @@ -435,15 +435,15 @@ void broadcast_input(ov::TensorVector& inputs, new_shape.insert(new_shape.end(), separate_shape.begin(), separate_shape.end()); } - if (input.get_shape() == new_shape) { + if (input.get_shape() == new_shape.to_shape()) { return; } OPENVINO_ASSERT(old_shape.size() <= new_shape.size()); - auto output = ov::Tensor(input.get_element_type(), new_shape); - std::vector broadcast_axes(old_shape.size()); std::iota(broadcast_axes.begin(), broadcast_axes.end(), new_shape.size() - old_shape.size()); + OPENVINO_ASSERT(PartialShape::broadcast_merge_into(new_shape, old_shape, ov::op::AutoBroadcastType::NUMPY)); + auto output = ov::Tensor(input.get_element_type(), new_shape.to_shape()); reference::broadcast(reinterpret_cast(input.data()), reinterpret_cast(output.data()), @@ -853,8 +853,10 @@ void contract_two_inputs(ov::TensorVector& inputs, PartialShape common_sub_shape1 = compute_sub_shape(input_shape1, common_dims_begin, common_dims_end); PartialShape common_sub_shape2 = compute_sub_shape(input_shape2, common_dims_begin2, common_dims_end2); - Shape reduced_sub_shape_prod = compute_sub_shape(input_shape1, reduced_dims_begin, reduced_dims_end, true); - Shape reduced_sub_shape = compute_sub_shape(input_shape1, reduced_dims_begin, reduced_dims_end); + PartialShape reduced_sub_shape_prod = compute_sub_shape(input_shape1, reduced_dims_begin, reduced_dims_end, true); + PartialShape reduced_sub_shape = compute_sub_shape(input_shape1, reduced_dims_begin, reduced_dims_end); + Shape reduced_sub_shape_prod2 = compute_sub_shape(input_shape2, reduced_dims_begin2, reduced_dims_end2, true); + Shape reduced_sub_shape2 = compute_sub_shape(input_shape2, reduced_dims_begin2, reduced_dims_end2); Shape separate1_sub_shape = compute_sub_shape(input_shape1, separate1_dims_begin, separate1_dims_end); Shape separate2_sub_shape = compute_sub_shape(input_shape2, separate2_dims_begin, separate2_dims_end); @@ -862,29 +864,32 @@ void contract_two_inputs(ov::TensorVector& inputs, // in case of ellipsis among the common labels // reference::broadcast() PartialShape::broadcast_merge_into(common_sub_shape1, common_sub_shape2, op::AutoBroadcastType::NUMPY); + PartialShape::broadcast_merge_into(reduced_sub_shape, reduced_sub_shape2, op::AutoBroadcastType::NUMPY); + PartialShape::broadcast_merge_into(reduced_sub_shape_prod, reduced_sub_shape_prod2, op::AutoBroadcastType::NUMPY); Shape common_sub_shape = 
common_sub_shape1.get_shape(); broadcast_input(inputs, input_ind1, common_sub_shape, separate1_sub_shape, - reduced_sub_shape, + reduced_sub_shape.get_shape(), is_separate_first1); broadcast_input(inputs, input_ind2, common_sub_shape, separate2_sub_shape, - reduced_sub_shape, + reduced_sub_shape.get_shape(), is_separate_first2); ov::Tensor matmul_operand1 = reshape_input_for_matmul(input1, common_sub_shape, separate1_sub_shape, - reduced_sub_shape_prod, + reduced_sub_shape_prod.get_shape(), is_separate_first1); + ov::Tensor matmul_operand2 = reshape_input_for_matmul(input2, common_sub_shape, separate2_sub_shape, - reduced_sub_shape_prod, + reduced_sub_shape_prod.get_shape(), is_separate_first2); // step 3. apply MatMul operation for formatted inputs diff --git a/src/core/shape_inference/include/einsum_shape_inference.hpp b/src/core/shape_inference/include/einsum_shape_inference.hpp index eb84482af0f052..5de11922f894a4 100644 --- a/src/core/shape_inference/include/einsum_shape_inference.hpp +++ b/src/core/shape_inference/include/einsum_shape_inference.hpp @@ -31,15 +31,19 @@ std::vector shape_infer(const Einsum* op, const std::vector& input_s for (size_t input_idx = 0; input_idx < input_shapes.size(); ++input_idx) { const auto& pshape = input_shapes[input_idx]; const auto labels = Einsum::extract_labels(input_subscripts[input_idx]); + const auto has_ellipsis = std::any_of(labels.begin(), labels.end(), [](std::string label) { + return label == "..."; + }); if (pshape.rank().is_static()) { size_t input_rank = pshape.size(); // check that a rank is greater or equal to a number of labels // these numbers are always equal if there is no ellipsis in the subscript - NODE_VALIDATION_CHECK(op, - input_rank >= labels.size(), - "Input rank must be greater or equal to a number of labels in the " - "corresponding input subscript."); + NODE_VALIDATION_CHECK( + op, + (input_rank >= (labels.size() - 1) && has_ellipsis) || (input_rank == labels.size() && !has_ellipsis), + "Input rank must be greater or equal to a number of labels in the " + "corresponding input subscript."); for (size_t label_ind = 0, dim_ind = 0; label_ind < labels.size() && dim_ind < input_rank; ++label_ind) { auto const& label = labels[label_ind]; @@ -64,15 +68,20 @@ std::vector shape_infer(const Einsum* op, const std::vector& input_s label_to_shape[label] = TRShape{pshape[dim_ind]}; } else { NODE_VALIDATION_CHECK(op, - label_to_shape[label].compatible(TRShape{pshape[label_ind]}), + TRShape::broadcast_merge_into(label_to_shape[label], + TRShape{pshape[dim_ind]}, + op::AutoBroadcastType::NUMPY), "Different input dimensions indicated by the same labels for Einsum " "must be compatible."); - OPENVINO_ASSERT(TRShape::merge_into(label_to_shape[label], TRShape{pshape[dim_ind]})); } ++dim_ind; } } } else { + if (has_ellipsis) { + // Shape has dynamic rank and ellipsis + return {pshape}; + } for (auto const& label : labels) { NODE_VALIDATION_CHECK(op, label != "...", diff --git a/src/core/src/op/einsum.cpp b/src/core/src/op/einsum.cpp index 281dc58d07684e..8c6e6b34040760 100644 --- a/src/core/src/op/einsum.cpp +++ b/src/core/src/op/einsum.cpp @@ -139,11 +139,6 @@ void op::v7::Einsum::parse_equation(const std::string& equation, OPENVINO_ASSERT(is_subscript_correct(output_subscript, output_is_ellipsis_met), "Output subscript of Einsum equation must consist of either only " "alphabetic letters or alphabetic letters with one ellipsis."); - - // if the ellipsis is met in input subscripts, one ellipsis must be in the output subscript - 
OPENVINO_ASSERT(is_ellipsis_met == output_is_ellipsis_met, - "Output subscript of Einsum equation must contain one ellipsis if " - "ellipsis is met in any input subscript."); } } diff --git a/src/core/tests/type_prop/einsum.cpp b/src/core/tests/type_prop/einsum.cpp index 9fbb04fcc1b610..455c2840f432ac 100644 --- a/src/core/tests/type_prop/einsum.cpp +++ b/src/core/tests/type_prop/einsum.cpp @@ -177,7 +177,7 @@ TEST_F(TypePropEinsumTest, dynamic_shape_diag_extraction) { EXPECT_EQ(o->get_element_type(), et); EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count); EXPECT_EQ(o->get_output_partial_shape(0), PartialShape({{3, 5}, 3, 4})); - EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), ElementsAre(symbols[0], symbols[1], symbols[2])); + EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), ElementsAre(symbols[0], symbols[4], symbols[2])); } TEST_F(TypePropEinsumTest, dynamic_shape_ellipsis) { @@ -372,14 +372,212 @@ TEST_F(TypePropEinsumTest, incorrect_equation_not_broadcastable_shapes) { HasSubstr("Input dimensions labeled with ellipsis for Einsum must be broadcastable.")); } -TEST_F(TypePropEinsumTest, incorrect_equation_missed_ellipsis) { +TEST_F(TypePropEinsumTest, missed_out_ellipsis) { const std::string equation = "a...b,b...->a"; - const auto input_shapes = Shapes{{11, 1, 4, 3}, {3, 11, 7, 5}}; + const auto input_shapes = Shapes{{11, 1, 4, 3}, {3, 11, 7, 4}}; + const auto inputs = make_inputs(element::f32, input_shapes); + const auto o = make_op(inputs, equation); + EXPECT_EQ(o->get_equation(), equation); + EXPECT_EQ(o->get_element_type(), element::f32); + EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count); + EXPECT_EQ(o->get_output_partial_shape(0), PartialShape({11})); + EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr)); +} + +TEST_F(TypePropEinsumTest, missed_rhs_out_ellipsis) { + const std::string equation = "a...b,b->a"; + + const auto input_shapes = Shapes{{11, 1, 4, 3}, {3}}; + const auto inputs = make_inputs(element::f32, input_shapes); + const auto o = make_op(inputs, equation); + EXPECT_EQ(o->get_equation(), equation); + EXPECT_EQ(o->get_element_type(), element::f32); + EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count); + EXPECT_EQ(o->get_output_partial_shape(0), PartialShape({11})); + EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr)); +} + +TEST_F(TypePropEinsumTest, missed_lhs_out_ellipsis) { + const std::string equation = "ab,b...->a"; + + const auto input_shapes = Shapes{{11, 3}, {3, 11, 7, 4}}; + const auto inputs = make_inputs(element::f32, input_shapes); + const auto o = make_op(inputs, equation); + EXPECT_EQ(o->get_equation(), equation); + EXPECT_EQ(o->get_element_type(), element::f32); + EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count); + EXPECT_EQ(o->get_output_partial_shape(0), PartialShape({11})); + EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr)); +} + +TEST_F(TypePropEinsumTest, missed_rhs_ellipsis) { + const std::string equation = "a...b,b->a..."; + + const auto input_shapes = Shapes{{11, 1, 4, 3}, {3}}; + const auto inputs = make_inputs(element::f32, input_shapes); + const auto o = make_op(inputs, equation); + EXPECT_EQ(o->get_equation(), equation); + EXPECT_EQ(o->get_element_type(), element::f32); + EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count); + EXPECT_EQ(o->get_output_partial_shape(0), PartialShape({11, 1, 4})); + EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr)); +} + 
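A minimal standalone sketch of the NumPy broadcasting rule the ellipsis tests in this hunk rely on, using only the public ov::PartialShape API that the patch itself calls elsewhere (broadcast_merge_into); the shapes below are illustrative, not taken from any test:

#include <iostream>

#include "openvino/core/partial_shape.hpp"
#include "openvino/op/util/attr_types.hpp"

int main() {
    // Suppose "..." of one input covers {11, 1, 4} and the other input
    // contributes {7, 4}. NumPy rules align ranks to the right and stretch
    // 1-sized dimensions, so the merged ellipsis sub-shape is {11, 7, 4}.
    ov::PartialShape merged{11, 1, 4};
    const bool compatible =
        ov::PartialShape::broadcast_merge_into(merged, {7, 4}, ov::op::AutoBroadcastType::NUMPY);
    std::cout << compatible << " " << merged << std::endl;  // prints "1 [11,7,4]"
    return 0;
}
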
+TEST_F(TypePropEinsumTest, missed_lhs_ellipsis) { + const std::string equation = "ab,b...->a..."; + + const auto input_shapes = Shapes{{11, 3}, {3, 11, 7, 4}}; + const auto inputs = make_inputs(element::f32, input_shapes); + const auto o = make_op(inputs, equation); + EXPECT_EQ(o->get_equation(), equation); + EXPECT_EQ(o->get_element_type(), element::f32); + EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count); + EXPECT_EQ(o->get_output_partial_shape(0), PartialShape({11, 11, 7, 4})); + EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr)); +} + +TEST_F(TypePropEinsumTest, missed_rhs_ellipsis_implicit) { + const std::string equation = "a...b,b"; + + const auto input_shapes = Shapes{{11, 1, 4, 3}, {3}}; + const auto inputs = make_inputs(element::f32, input_shapes); + const auto o = make_op(inputs, equation); + EXPECT_EQ(o->get_equation(), equation); + EXPECT_EQ(o->get_element_type(), element::f32); + EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count); + EXPECT_EQ(o->get_output_partial_shape(0), PartialShape({1, 4, 11})); + EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr)); +} + +TEST_F(TypePropEinsumTest, missed_lhs_ellipsis_implicit) { + const std::string equation = "ab,b..."; + + const auto input_shapes = Shapes{{11, 3}, {3, 11, 7, 4}}; + const auto inputs = make_inputs(element::f32, input_shapes); + const auto o = make_op(inputs, equation); + EXPECT_EQ(o->get_equation(), equation); + EXPECT_EQ(o->get_element_type(), element::f32); + EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count); + EXPECT_EQ(o->get_output_partial_shape(0), PartialShape({11, 7, 4, 11})); + EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr)); +} + +TEST_F(TypePropEinsumTest, all_dynamic_rank_ellipsis) { + const std::string equation = "a...b,b...->...a"; + constexpr auto et = element::i32; + + auto input_shapes = PartialShapes(2, PartialShape::dynamic()); + const auto inputs = make_inputs(et, input_shapes); + const auto o = make_op(inputs, equation); + + EXPECT_EQ(o->get_equation(), equation); + EXPECT_EQ(o->get_element_type(), et); + EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count); + EXPECT_EQ(o->get_output_partial_shape(0), PartialShape::dynamic()); + EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr)); +} + +TEST_F(TypePropEinsumTest, broadcasting_same_symbol_common) { + const std::string equation = "ab,ba->b"; + constexpr auto et = element::i32; + + auto input_shapes = Shapes{{7, 5}, {1, 7}}; + ; + const auto inputs = make_inputs(et, input_shapes); + const auto o = make_op(inputs, equation); + + EXPECT_EQ(o->get_equation(), equation); + EXPECT_EQ(o->get_element_type(), et); + EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count); + EXPECT_EQ(o->get_output_partial_shape(0), PartialShape({5})); + EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr)); +} + +TEST_F(TypePropEinsumTest, broadcasting_same_symbol_reduced) { + const std::string equation = "ab,ba->b"; + constexpr auto et = element::i32; + + auto input_shapes = Shapes{{1, 5}, {5, 7}}; + ; + const auto inputs = make_inputs(et, input_shapes); + const auto o = make_op(inputs, equation); + + EXPECT_EQ(o->get_equation(), equation); + EXPECT_EQ(o->get_element_type(), et); + EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count); + EXPECT_EQ(o->get_output_partial_shape(0), PartialShape({5})); + EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr)); +} + +TEST_F(TypePropEinsumTest, 
broadcasting_same_symbol) { + const std::string equation = "ab,ba->b"; + constexpr auto et = element::i32; + + auto input_shapes = Shapes{{7, 1}, {5, 1}}; + ; + const auto inputs = make_inputs(et, input_shapes); + const auto o = make_op(inputs, equation); + + EXPECT_EQ(o->get_equation(), equation); + EXPECT_EQ(o->get_element_type(), et); + EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count); + EXPECT_EQ(o->get_output_partial_shape(0), PartialShape({5})); + EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr)); +} + +TEST_F(TypePropEinsumTest, ellipsis_no_dimension) { + const std::string equation = "...ab,ba...->b..."; + constexpr auto et = element::i32; + + auto input_shapes = Shapes{{5, 1}, {5, 5}}; + ; + const auto inputs = make_inputs(et, input_shapes); + const auto o = make_op(inputs, equation); + + EXPECT_EQ(o->get_equation(), equation); + EXPECT_EQ(o->get_element_type(), et); + EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count); + EXPECT_EQ(o->get_output_partial_shape(0), PartialShape({5})); + EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr)); +} + +TEST_F(TypePropEinsumTest, ellipsis_dynamic_shape) { + const std::string equation = "...ab,ba...->b..."; + constexpr auto et = element::i32; + + auto input_shapes = PartialShapes{{-1, 57, 5, 5}, {5, 5}}; + ; + const auto inputs = make_inputs(et, input_shapes); + const auto o = make_op(inputs, equation); + + EXPECT_EQ(o->get_equation(), equation); + EXPECT_EQ(o->get_element_type(), et); + EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count); + EXPECT_EQ(o->get_output_partial_shape(0), PartialShape({5, -1, 57})); + EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr)); +} + +TEST_F(TypePropEinsumTest, input_rank_incompatible_with_equation) { + const std::string equation = "ab,bc->ac"; + + const auto input_shapes = Shapes{{2, 2, 10}, {3, 4}}; + const auto inputs = make_inputs(element::f32, input_shapes); + + OV_EXPECT_THROW(auto o = make_op(inputs, equation), + AssertFailure, + HasSubstr("Input rank must be greater or equal to a number of labels in the " + "corresponding input subscript.")); +} + +TEST_F(TypePropEinsumTest, input_rank_incompatible_with_equation_single_input) { + const std::string equation = "ab->ba"; + + const auto input_shapes = Shapes{{3, 5, 7}}; const auto inputs = make_inputs(element::f32, input_shapes); OV_EXPECT_THROW(auto o = make_op(inputs, equation), AssertFailure, - HasSubstr("Output subscript of Einsum equation must contain one " - "ellipsis if ellipsis is met in any input subscript.")); + HasSubstr("Input rank must be greater or equal to a number of labels in the " + "corresponding input subscript.")); } From 50c98c189d4e7204a2afab03f36c0d3791005b42 Mon Sep 17 00:00:00 2001 From: MATEUSZ MIKOLAJCZYK Date: Thu, 19 Dec 2024 18:07:55 +0000 Subject: [PATCH 02/47] Einsum decomposition broadcasting + ellipsis support Signed-off-by: MATEUSZ MIKOLAJCZYK --- .../op_conversions/einsum_decomposition.cpp | 300 +++++++++++++++--- 1 file changed, 264 insertions(+), 36 deletions(-) diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp index 7955e37cfcda14..7d93928243b78d 100644 --- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp @@ -9,6 +9,7 @@ #include "itt.hpp" #include 
"openvino/core/rt_info.hpp" #include "openvino/core/validation_util.hpp" +#include "openvino/op/broadcast.hpp" #include "openvino/op/concat.hpp" #include "openvino/op/constant.hpp" #include "openvino/op/einsum.hpp" @@ -26,7 +27,7 @@ namespace { /// \brief Check if the EinsumDecomposition transformation is applicable to a given Einsum. -/// The transformation is applicable if input subscript does not have repeated labels and ellipsis. +/// The transformation is applicable if input subscript does not have repeated labels. /// /// \param subscript A subscript to check its format /// @@ -35,7 +36,7 @@ namespace { bool is_subscript_applicable(const std::string& subscript) { auto labels = ov::op::v7::Einsum::extract_labels(subscript); auto unique_labels = std::unordered_set(labels.begin(), labels.end()); - return std::find(labels.begin(), labels.end(), "...") == labels.end() && unique_labels.size() == labels.size(); + return unique_labels.size() == labels.size(); } /// \brief Compute einsum_path for a given Einsum node meaning that the (pseudo-)optimal @@ -174,7 +175,88 @@ void update_operands(ov::OutputVector& input_nodes, input_subscripts.erase(input_subscripts.begin() + input_ind1); input_subscripts.push_back(new_subscript); } +using LabelDimMap = std::unordered_map>; + +LabelDimMap compute_label_dim_map(const ov::Rank& input_rank, const std::string& input_subscript) { + static const std::string ellipsis = "..."; + const auto labels = ov::op::v7::Einsum::extract_labels(input_subscript); + const auto static_input_rank = input_rank.is_static(); + OPENVINO_ASSERT(static_input_rank || (std::find(labels.begin(), labels.end(), ellipsis) == labels.end()), + "Input rank cannot be dynamic in case of ellipsis in input subscript"); + const size_t input_rank_length = static_input_rank ? 
input_rank.get_length() : labels.size(); + OPENVINO_ASSERT(input_rank_length >= labels.size()); + const size_t num_broadcasted_dims = input_rank_length - labels.size() + 1; + OPENVINO_ASSERT(num_broadcasted_dims > 0); + + LabelDimMap resulted_map; + size_t current_dim = 0; + for (const auto& label : labels) { + if (label == ellipsis) { + std::vector label_dims(num_broadcasted_dims); + std::iota(label_dims.begin(), label_dims.end(), current_dim); + resulted_map[label] = label_dims; + current_dim += num_broadcasted_dims; + } else if (resulted_map.find(label) != resulted_map.end()) { + resulted_map[label].push_back(current_dim); + ++current_dim; + } else { + std::vector label_dims; + label_dims.push_back(current_dim); + resulted_map[label] = label_dims; + ++current_dim; + } + } + + return resulted_map; +} + +void compute_ranges(const ov::Rank& input_rank, + const std::string& input_subscript, + const std::vector& common_labels, + const std::vector& sep_labels, + const std::vector& reduced_labels, + size_t& common_begin, + size_t& common_end, + size_t& sep_begin, + size_t& sep_end, + size_t& reduced_begin, + size_t& reduced_end, + bool is_separated_first) { + auto label_to_dim_map = compute_label_dim_map(input_rank, input_subscript); + static const std::string ellipsis = "..."; + + size_t common_rank = common_labels.size(); + if (std::find(common_labels.begin(), common_labels.end(), ellipsis) != common_labels.end()) { + OPENVINO_ASSERT(label_to_dim_map.find(ellipsis) != label_to_dim_map.end()); + common_rank += label_to_dim_map[ellipsis].size() - 1; + } + + size_t sep_rank = sep_labels.size(); + if (std::find(sep_labels.begin(), sep_labels.end(), ellipsis) != sep_labels.end()) { + OPENVINO_ASSERT(label_to_dim_map.find(ellipsis) != label_to_dim_map.end()); + sep_rank += label_to_dim_map[ellipsis].size() - 1; + } + size_t reduced_rank = reduced_labels.size(); + if (std::find(reduced_labels.begin(), reduced_labels.end(), ellipsis) != reduced_labels.end()) { + OPENVINO_ASSERT(label_to_dim_map.find(ellipsis) != label_to_dim_map.end()); + reduced_rank += label_to_dim_map[ellipsis].size() - 1; + } + + common_begin = 0; + common_end = common_begin + common_rank; + if (is_separated_first) { + sep_begin = common_end; + sep_end = sep_begin + sep_rank; + reduced_begin = sep_end; + reduced_end = reduced_begin + reduced_rank; + } else { + reduced_begin = common_end; + reduced_end = reduced_begin + reduced_rank; + sep_begin = reduced_end; + sep_end = sep_begin + sep_rank; + } +} /// \brief Return input node with computed sub-shape defined by a range [s_begin;s_end) /// /// \param data_shape Input node that contains some tensor shape @@ -243,6 +325,84 @@ ov::Output unsqueeze_input(const ov::Output& input_node, return unsqueeze->output(0); } +ov::OutputVector broadcast_merge_shapes(ov::OutputVector& shapes_lhs, + ov::OutputVector& shapes_rhs, + ov::NodeVector& subgraph_nodes) { + // TODO - Refactor func to remove loop and duplicated Broadcast. 
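+    // Each pair shapes_lhs[i] / shapes_rhs[i] is assumed to hold a 1D shape tensor.
+    // The loop below merges each pair by NumPy broadcasting rules without host-side
+    // arithmetic: a scalar 1 is first stretched to shapes_lhs[i] (NUMPY), the result
+    // is then broadcast bidirectionally against shapes_rhs[i], and ShapeOf of that
+    // tensor yields the merged shape.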
+    OPENVINO_ASSERT(shapes_lhs.size() == shapes_rhs.size());
+    ov::OutputVector broadcasted_shape_nodes{shapes_lhs.size()};
+
+    for (size_t shp_i = 0; shp_i < shapes_lhs.size(); shp_i++) {
+        auto const_1 = ov::op::v0::Constant::create(ov::element::Type_t::i64, ov::Shape{1}, {1});
+        auto tmp_const_of_lhs_shp = std::make_shared<ov::op::v3::Broadcast>(const_1,
+                                                                            shapes_lhs[shp_i],
+                                                                            ov::op::BroadcastType::NUMPY);
+        auto tmp_const_of_broadcasted_shp =
+            std::make_shared<ov::op::v3::Broadcast>(tmp_const_of_lhs_shp,
+                                                    shapes_rhs[shp_i],
+                                                    ov::op::BroadcastType::BIDIRECTIONAL);
+        auto broadcasted_shape = std::make_shared<ov::op::v3::ShapeOf>(tmp_const_of_broadcasted_shp);
+        broadcasted_shape_nodes[shp_i] = broadcasted_shape;
+        subgraph_nodes.insert(subgraph_nodes.end(),
+                              {const_1, tmp_const_of_lhs_shp, tmp_const_of_broadcasted_shp, broadcasted_shape});
+    }
+    return broadcasted_shape_nodes;
+}
+
+/// \brief Broadcast input node to the new shape specified by broadcasted sub-shapes of the common,
+/// separate and reduced dimensions so that the broadcasted input has a format acceptable
+/// by the subsequent Reshape and MatMul
+///
+/// \param input_node Input node to broadcast
+/// \param common_sub_shape A sub-shape corresponding to the broadcasted common dimensions
+/// \param separate_sub_shape A sub-shape corresponding to the broadcasted separate dimensions
+/// \param reduced_sub_shape A sub-shape corresponding to the broadcasted reduced dimensions
+/// \param is_separate_first true - the separate dimensions are placed before the reduced
+/// dimensions, otherwise they are placed after them
+/// \param subgraph_nodes A vector of operation nodes that is included into
+/// a sub-graph decomposing Einsum that is needed for copy_runtime_info
+///
+/// \return Broadcasted input node
+///
+ov::Output<ov::Node> broadcast_input(const ov::Output<ov::Node>& input_node,
+                                     const ov::OutputVector& common_sub_shape,
+                                     const ov::OutputVector& separate_sub_shape,
+                                     const ov::OutputVector& reduced_sub_shape,
+                                     bool is_separate_first,
+                                     ov::NodeVector& subgraph_nodes) {
+    ov::OutputVector new_shape_parts;
+    new_shape_parts.insert(new_shape_parts.end(), common_sub_shape.begin(), common_sub_shape.end());
+    // form a new shape for the input so that collapsed dimensions corresponding
+    // to the common, separate and reduced dimensions are placed in the correct order
+    if (is_separate_first) {
+        new_shape_parts.insert(new_shape_parts.end(), separate_sub_shape.begin(), separate_sub_shape.end());
+        new_shape_parts.insert(new_shape_parts.end(), reduced_sub_shape.begin(), reduced_sub_shape.end());
+    } else {
+        new_shape_parts.insert(new_shape_parts.end(), reduced_sub_shape.begin(), reduced_sub_shape.end());
+        new_shape_parts.insert(new_shape_parts.end(), separate_sub_shape.begin(), separate_sub_shape.end());
+    }
+
+    // in case of a scalar, the broadcast is not needed
+    if (new_shape_parts.size() == 0) {
+        return input_node;
+    }
+    auto new_shape_op = std::make_shared<ov::op::v0::Concat>(new_shape_parts, 0);
+    // if the new shape can be computed at the shape inference stage, insert a Constant node immediately
+    // in order to prevent repeated computing during the constant-folding pass
+    std::shared_ptr<ov::Node> reshaped_input_op;
+    if (auto new_shape_const = ov::util::get_constant_from_source(new_shape_op)) {
+        reshaped_input_op =
+            std::make_shared<ov::op::v3::Broadcast>(input_node, new_shape_const, ov::op::BroadcastType::BIDIRECTIONAL);
+        subgraph_nodes.insert(subgraph_nodes.end(), {new_shape_const});
+    } else {
+        reshaped_input_op = std::make_shared<ov::op::v3::Broadcast>(input_node,
+                                                                    new_shape_op->output(0),
+                                                                    ov::op::BroadcastType::BIDIRECTIONAL);
+        subgraph_nodes.insert(subgraph_nodes.end(), {new_shape_op});
+    }
+
+    
subgraph_nodes.insert(subgraph_nodes.end(), {reshaped_input_op}); + return reshaped_input_op->output(0); +} + /// \brief Reshape input node to the new shape specified by sub-shapes of the common, /// separate and reduced dimensions so that the reshaped input has a format acceptable by MatMul /// @@ -334,7 +494,7 @@ void transpose_input(ov::OutputVector& input_nodes, size_t input_ind, ov::NodeVector& subgraph_nodes) { // perform sanity check for arguments - auto num_inputs = input_nodes.size(); + const auto num_inputs = input_nodes.size(); OPENVINO_ASSERT(num_inputs == input_subscripts.size(), "Each input must have own subscript."); OPENVINO_ASSERT(input_ind < num_inputs, "Input index is out of range."); @@ -350,21 +510,22 @@ void transpose_input(ov::OutputVector& input_nodes, // find permutation that establishes bijection between the input subscript // and the required one - auto labels = ov::op::v7::Einsum::extract_labels(input_subscript); - auto required_labels = ov::op::v7::Einsum::extract_labels(required_subscript); + const auto& input_node = input_nodes[input_ind]; + const auto labels = ov::op::v7::Einsum::extract_labels(input_subscript); + const auto required_labels = ov::op::v7::Einsum::extract_labels(required_subscript); OPENVINO_ASSERT(labels.size() == required_labels.size()); + const auto label_dim_map = compute_label_dim_map(input_node.get_partial_shape().rank(), input_subscript); for (const auto& required_label : required_labels) { - auto it = std::find(labels.begin(), labels.end(), required_label); - OPENVINO_ASSERT(it != labels.end()); - int64_t found_index = static_cast(it - labels.begin()); - permutation.push_back(found_index); + const auto label_dims_it = label_dim_map.find(required_label); + OPENVINO_ASSERT(label_dims_it != label_dim_map.end()); + const auto& label_dims = label_dims_it->second; + permutation.insert(permutation.end(), label_dims.begin(), label_dims.end()); } // create a sub-graph for transposing into the required layout - const auto& input_node = input_nodes[input_ind]; - auto permutation_const = + const auto permutation_const = ov::op::v0::Constant::create(ov::element::Type_t::i64, ov::Shape{permutation.size()}, permutation); - auto transpose = std::make_shared(input_node, permutation_const); + const auto transpose = std::make_shared(input_node, permutation_const); // update a vector of inputs and input subscripts input_nodes[input_ind] = transpose->output(0); @@ -468,6 +629,11 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, const auto& input_node1 = input_nodes[input_ind1]; const auto& input_node2 = input_nodes[input_ind2]; + // extract diagonals in case repeated labels in the corresponding input subscripts + // TODO + // extract_diagonal(einsum_decompose_ptr, input_nodes, input_subscripts, input_ind1, subgraph_nodes); + // extract_diagonal(einsum_decompose_ptr, input_nodes, input_subscripts, input_ind2, subgraph_nodes); + // reduce dimensions for input operands if possible reduce_input(einsum_decompose_ptr, input_nodes, input_subscripts, output_subscript, input_ind1, subgraph_nodes); reduce_input(einsum_decompose_ptr, input_nodes, input_subscripts, output_subscript, input_ind2, subgraph_nodes); @@ -491,6 +657,7 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, std::vector common_labels_inds1, common_labels_inds2; std::vector separate_labels_inds1, separate_labels_inds2; std::vector reduced_labels_inds1, reduced_labels_inds2; + std::vector common_labels, sep_labels1, sep_labels2, 
reduced_labels; // +++++ for (size_t label_ind = 0; label_ind < labels1.size(); ++label_ind) { const auto& label = labels1[label_ind]; auto iter = std::find(labels2.begin(), labels2.end(), label); @@ -501,13 +668,16 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, if (is_dim_reduced) { reduced_labels_inds1.push_back(static_cast(label_ind)); reduced_labels_inds2.push_back(static_cast(iter - labels2.begin())); + reduced_labels.push_back(label); } else { common_labels_inds1.push_back(static_cast(label_ind)); common_labels_inds2.push_back(static_cast(iter - labels2.begin())); + common_labels.push_back(label); } } else { separate_part1 += label; separate_labels_inds1.push_back(static_cast(label_ind)); + sep_labels1.push_back(label); } } for (size_t label_ind = 0; label_ind < labels2.size(); ++label_ind) { @@ -516,6 +686,7 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, if (iter == labels1.end()) { separate_part2 += label; separate_labels_inds2.push_back(static_cast(label_ind)); + sep_labels2.push_back(label); } } @@ -601,26 +772,71 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, auto matmul_operand1 = input_node1; auto matmul_operand2 = input_node2; - int64_t common_dims_begin = 0; - int64_t common_dims_end = common_labels_inds1.size(); + + size_t common_dims_begin, common_dims_end, reduced_dims_begin, reduced_dims_end, separate1_dims_begin, + separate1_dims_end; + compute_ranges(input_node1.get_partial_shape().rank(), + input_subscript1, + common_labels, + sep_labels1, + reduced_labels, + common_dims_begin, + common_dims_end, + separate1_dims_begin, + separate1_dims_end, + reduced_dims_begin, + reduced_dims_end, + is_separate_first1); + + size_t common_dims_begin2, common_dims_end2, reduced_dims_begin2, reduced_dims_end2, separate2_dims_begin, + separate2_dims_end; + compute_ranges(input_node2.get_partial_shape().rank(), + input_subscript2, + common_labels, + sep_labels2, + reduced_labels, + common_dims_begin2, + common_dims_end2, + separate2_dims_begin, + separate2_dims_end, + reduced_dims_begin2, + reduced_dims_end2, + is_separate_first2); + + no_reshape_for_matmul1 = false; + no_reshape_for_matmul2 = false; + // // no_reshape_after_matmul = false; ov::OutputVector common_sub_shape, separate1_sub_shape, separate2_sub_shape; + if (no_reshape_for_matmul1 == false || no_reshape_for_matmul2 == false) { auto data_shape1 = std::make_shared(input_node1); + auto data_shape2 = std::make_shared(input_node2); common_sub_shape = compute_sub_shape(data_shape1, common_dims_begin, common_dims_end, subgraph_nodes); - int64_t reduced_dims_begin = (is_separate_first1 ? 
common_labels_inds1.size() + separate_labels_inds1.size() - : common_labels_inds1.size()); - int64_t reduced_dims_end = reduced_dims_begin + reduced_labels_inds1.size(); + auto common_sub_shape2 = compute_sub_shape(data_shape2, common_dims_begin2, common_dims_end2, subgraph_nodes); + OPENVINO_ASSERT(common_sub_shape.size() == common_sub_shape2.size()); + common_sub_shape = broadcast_merge_shapes(common_sub_shape, common_sub_shape2, subgraph_nodes); auto reduced_sub_shape_prod = compute_sub_shape(data_shape1, reduced_dims_begin, reduced_dims_end, subgraph_nodes, true); - + auto reduced_sub_shape_prod2 = + compute_sub_shape(data_shape2, reduced_dims_begin2, reduced_dims_end2, subgraph_nodes, true); + auto reduced_sub_shape = + compute_sub_shape(data_shape1, reduced_dims_begin, reduced_dims_end, subgraph_nodes, false); + auto reduced_sub_shape2 = + compute_sub_shape(data_shape2, reduced_dims_begin2, reduced_dims_end2, subgraph_nodes, false); + + reduced_sub_shape_prod = + broadcast_merge_shapes(reduced_sub_shape_prod, reduced_sub_shape_prod2, subgraph_nodes); + reduced_sub_shape = broadcast_merge_shapes(reduced_sub_shape, reduced_sub_shape2, subgraph_nodes); if (no_reshape_for_matmul1 == false || no_reshape_after_matmul == false) { - int64_t separate1_dims_begin = - (is_separate_first1 ? common_labels_inds1.size() - : common_labels_inds1.size() + reduced_labels_inds1.size()); - int64_t separate1_dims_end = separate1_dims_begin + separate_labels_inds1.size(); separate1_sub_shape = compute_sub_shape(data_shape1, separate1_dims_begin, separate1_dims_end, subgraph_nodes); - matmul_operand1 = reshape_input_for_matmul(input_node1, + auto broadcasted1 = broadcast_input(input_node1, + common_sub_shape, + separate1_sub_shape, + reduced_sub_shape, + is_separate_first1, + subgraph_nodes); + matmul_operand1 = reshape_input_for_matmul(broadcasted1, common_sub_shape, separate1_sub_shape, reduced_sub_shape_prod, @@ -629,14 +845,15 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, } if (no_reshape_for_matmul2 == false || no_reshape_after_matmul == false) { - auto data_shape2 = std::make_shared(input_node2); - int64_t separate2_dims_begin = - (is_separate_first2 ? common_labels_inds2.size() - : common_labels_inds2.size() + reduced_labels_inds2.size()); - int64_t separate2_dims_end = separate2_dims_begin + separate_labels_inds2.size(); separate2_sub_shape = compute_sub_shape(data_shape2, separate2_dims_begin, separate2_dims_end, subgraph_nodes); - matmul_operand2 = reshape_input_for_matmul(input_node2, + auto broadcasted2 = broadcast_input(input_node2, + common_sub_shape, + separate2_sub_shape, + reduced_sub_shape, + is_separate_first2, + subgraph_nodes); + matmul_operand2 = reshape_input_for_matmul(broadcasted2, common_sub_shape, separate2_sub_shape, reduced_sub_shape_prod, @@ -654,8 +871,11 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, // step 4. 
reshape back by unrolling dimensions corresponding to separate labels if needed // now dimensions corresponding to reduced labels are reduced by the MatMul operation - std::string resultant_subscript = - input_subscript1.substr(common_dims_begin, common_dims_end) + separate_part1 + separate_part2; + common_part = ""; + for (const auto& common_label : common_labels) { + common_part += common_label; + } + const std::string resultant_subscript = common_part + separate_part1 + separate_part2; if (no_reshape_after_matmul) { // this is a case when Reshape is not needed after MatMul operation // since there are no collapsed (or auxiliary added) separated dimensions @@ -667,12 +887,12 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, new_shape.insert(new_shape.end(), separate2_sub_shape.begin(), separate2_sub_shape.end()); auto result_shape_op = std::make_shared(new_shape, 0); - // if new shape is possible to compute on the shape infer stage, insert Constant node immediatelly + // if new shape is possible to compute on the shape infer stage, insert Constant node immediately // in order to prevent repeated computing during constant-folding pass std::shared_ptr result_op; if (auto new_shape_const = ov::util::get_constant_from_source(result_shape_op)) { result_op = std::make_shared(matmul->output(0), new_shape_const, false); - subgraph_nodes.insert(subgraph_nodes.end(), {new_shape_const}); + subgraph_nodes.insert(subgraph_nodes.end(), {result_shape_op, new_shape_const}); } else { result_op = std::make_shared(matmul->output(0), result_shape_op->output(0), false); subgraph_nodes.insert(subgraph_nodes.end(), {result_shape_op}); @@ -723,6 +943,12 @@ ov::pass::EinsumDecomposition::EinsumDecomposition() { // and a vector of sub-graph nodes for copy_runtime_info ov::OutputVector input_nodes = einsum_node->input_values(); ov::NodeVector subgraph_nodes; + // check that the transformation is applicable + if (std::any_of(input_nodes.cbegin(), input_nodes.cend(), [](ov::Output node) { + return node.get_partial_shape().rank().is_dynamic(); + })) { + return false; + } // compute einsum path that is used to contract a pair of operands // in more optimal order @@ -739,13 +965,15 @@ ov::pass::EinsumDecomposition::EinsumDecomposition() { subgraph_nodes); } - // reduce dimensions for the remained input node OPENVINO_ASSERT(input_nodes.size() == 1); - reduce_input(this, input_nodes, input_subscripts, output_subscript, 0, subgraph_nodes); + // extract diagonal for the single operand + // TODO + // extract_diagonal(this, input_nodes, input_subscripts, 0, subgraph_nodes); + // reduce dimensions for the remained input node + reduce_input(this, input_nodes, input_subscripts, output_subscript, 0, subgraph_nodes); // transpose dimensions to layout required by the output subscript transpose_input(input_nodes, input_subscripts, output_subscript, 0, subgraph_nodes); - // replace the original Einsum node with the last node from decomposing sub-graph // preserve the original node name auto last_node = input_nodes[0].get_node_shared_ptr(); From d3eac209cd424d07f17fe894ef8c986f760bf0bb Mon Sep 17 00:00:00 2001 From: MATEUSZ MIKOLAJCZYK Date: Fri, 20 Dec 2024 11:21:32 +0000 Subject: [PATCH 03/47] Move broadcasting out of reshape conditional Signed-off-by: MATEUSZ MIKOLAJCZYK --- .../op_conversions/einsum_decomposition.cpp | 108 +++++++++--------- 1 file changed, 51 insertions(+), 57 deletions(-) diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp 
b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp index 7d93928243b78d..7ae13f270a03c5 100644 --- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp @@ -772,6 +772,8 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, auto matmul_operand1 = input_node1; auto matmul_operand2 = input_node2; + auto broadcasted_operand1 = input_node1; + auto broadcasted_operand2 = input_node2; size_t common_dims_begin, common_dims_end, reduced_dims_begin, reduced_dims_end, separate1_dims_begin, separate1_dims_end; @@ -803,65 +805,57 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, reduced_dims_end2, is_separate_first2); - no_reshape_for_matmul1 = false; - no_reshape_for_matmul2 = false; - // // no_reshape_after_matmul = false; ov::OutputVector common_sub_shape, separate1_sub_shape, separate2_sub_shape; - if (no_reshape_for_matmul1 == false || no_reshape_for_matmul2 == false) { - auto data_shape1 = std::make_shared(input_node1); - auto data_shape2 = std::make_shared(input_node2); - common_sub_shape = compute_sub_shape(data_shape1, common_dims_begin, common_dims_end, subgraph_nodes); - auto common_sub_shape2 = compute_sub_shape(data_shape2, common_dims_begin2, common_dims_end2, subgraph_nodes); - OPENVINO_ASSERT(common_sub_shape.size() == common_sub_shape2.size()); - common_sub_shape = broadcast_merge_shapes(common_sub_shape, common_sub_shape2, subgraph_nodes); - auto reduced_sub_shape_prod = - compute_sub_shape(data_shape1, reduced_dims_begin, reduced_dims_end, subgraph_nodes, true); - auto reduced_sub_shape_prod2 = - compute_sub_shape(data_shape2, reduced_dims_begin2, reduced_dims_end2, subgraph_nodes, true); - auto reduced_sub_shape = - compute_sub_shape(data_shape1, reduced_dims_begin, reduced_dims_end, subgraph_nodes, false); - auto reduced_sub_shape2 = - compute_sub_shape(data_shape2, reduced_dims_begin2, reduced_dims_end2, subgraph_nodes, false); - - reduced_sub_shape_prod = - broadcast_merge_shapes(reduced_sub_shape_prod, reduced_sub_shape_prod2, subgraph_nodes); - reduced_sub_shape = broadcast_merge_shapes(reduced_sub_shape, reduced_sub_shape2, subgraph_nodes); - if (no_reshape_for_matmul1 == false || no_reshape_after_matmul == false) { - separate1_sub_shape = - compute_sub_shape(data_shape1, separate1_dims_begin, separate1_dims_end, subgraph_nodes); - auto broadcasted1 = broadcast_input(input_node1, - common_sub_shape, - separate1_sub_shape, - reduced_sub_shape, - is_separate_first1, - subgraph_nodes); - matmul_operand1 = reshape_input_for_matmul(broadcasted1, - common_sub_shape, - separate1_sub_shape, - reduced_sub_shape_prod, - is_separate_first1, - subgraph_nodes); - } - - if (no_reshape_for_matmul2 == false || no_reshape_after_matmul == false) { - separate2_sub_shape = - compute_sub_shape(data_shape2, separate2_dims_begin, separate2_dims_end, subgraph_nodes); - auto broadcasted2 = broadcast_input(input_node2, - common_sub_shape, - separate2_sub_shape, - reduced_sub_shape, - is_separate_first2, - subgraph_nodes); - matmul_operand2 = reshape_input_for_matmul(broadcasted2, - common_sub_shape, - separate2_sub_shape, - reduced_sub_shape_prod, - is_separate_first2, - subgraph_nodes); - subgraph_nodes.insert(subgraph_nodes.end(), {data_shape2}); - } - subgraph_nodes.insert(subgraph_nodes.end(), {data_shape1}); + auto data_shape1 = std::make_shared(input_node1); + auto 
data_shape2 = std::make_shared(input_node2); + subgraph_nodes.insert(subgraph_nodes.end(), {data_shape1}); + subgraph_nodes.insert(subgraph_nodes.end(), {data_shape2}); + common_sub_shape = compute_sub_shape(data_shape1, common_dims_begin, common_dims_end, subgraph_nodes); + auto common_sub_shape2 = compute_sub_shape(data_shape2, common_dims_begin2, common_dims_end2, subgraph_nodes); + OPENVINO_ASSERT(common_sub_shape.size() == common_sub_shape2.size()); + common_sub_shape = broadcast_merge_shapes(common_sub_shape, common_sub_shape2, subgraph_nodes); + auto reduced_sub_shape_prod = + compute_sub_shape(data_shape1, reduced_dims_begin, reduced_dims_end, subgraph_nodes, true); + auto reduced_sub_shape_prod2 = + compute_sub_shape(data_shape2, reduced_dims_begin2, reduced_dims_end2, subgraph_nodes, true); + auto reduced_sub_shape = + compute_sub_shape(data_shape1, reduced_dims_begin, reduced_dims_end, subgraph_nodes, false); + auto reduced_sub_shape2 = + compute_sub_shape(data_shape2, reduced_dims_begin2, reduced_dims_end2, subgraph_nodes, false); + + reduced_sub_shape_prod = broadcast_merge_shapes(reduced_sub_shape_prod, reduced_sub_shape_prod2, subgraph_nodes); + reduced_sub_shape = broadcast_merge_shapes(reduced_sub_shape, reduced_sub_shape2, subgraph_nodes); + separate1_sub_shape = compute_sub_shape(data_shape1, separate1_dims_begin, separate1_dims_end, subgraph_nodes); + broadcasted_operand1 = broadcast_input(input_node1, + common_sub_shape, + separate1_sub_shape, + reduced_sub_shape, + is_separate_first1, + subgraph_nodes); + separate2_sub_shape = compute_sub_shape(data_shape2, separate2_dims_begin, separate2_dims_end, subgraph_nodes); + broadcasted_operand2 = broadcast_input(input_node2, + common_sub_shape, + separate2_sub_shape, + reduced_sub_shape, + is_separate_first2, + subgraph_nodes); + if (no_reshape_for_matmul1 == false || no_reshape_after_matmul == false) { + matmul_operand1 = reshape_input_for_matmul(broadcasted_operand1, + common_sub_shape, + separate1_sub_shape, + reduced_sub_shape_prod, + is_separate_first1, + subgraph_nodes); + } + + if (no_reshape_for_matmul2 == false || no_reshape_after_matmul == false) { + matmul_operand2 = reshape_input_for_matmul(broadcasted_operand2, + common_sub_shape, + separate2_sub_shape, + reduced_sub_shape_prod, + is_separate_first2, + subgraph_nodes); } // step 3. 
apply MatMul operation for formatted inputs From 0ec79742ec4f12fba96ddb6470cee59a8ded7487 Mon Sep 17 00:00:00 2001 From: MATEUSZ MIKOLAJCZYK Date: Tue, 7 Jan 2025 11:44:24 +0000 Subject: [PATCH 04/47] Initial support for repeated labels Signed-off-by: MATEUSZ MIKOLAJCZYK --- .../op_conversions/einsum_decomposition.cpp | 224 ++++++++++++++---- 1 file changed, 183 insertions(+), 41 deletions(-) diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp index 7ae13f270a03c5..067e73e3a0326c 100644 --- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp @@ -12,33 +12,25 @@ #include "openvino/op/broadcast.hpp" #include "openvino/op/concat.hpp" #include "openvino/op/constant.hpp" +#include "openvino/op/divide.hpp" #include "openvino/op/einsum.hpp" +#include "openvino/op/gather.hpp" #include "openvino/op/matmul.hpp" #include "openvino/op/multiply.hpp" +#include "openvino/op/range.hpp" #include "openvino/op/reduce_prod.hpp" #include "openvino/op/reduce_sum.hpp" #include "openvino/op/reshape.hpp" +#include "openvino/op/scatter_elements_update.hpp" #include "openvino/op/shape_of.hpp" #include "openvino/op/strided_slice.hpp" +#include "openvino/op/subtract.hpp" #include "openvino/op/transpose.hpp" #include "openvino/op/unsqueeze.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" #include "transformations/utils/utils.hpp" namespace { -/// \brief Check if the EinsumDecomposition transformation is applicable to a given Einsum. -/// The transformation is applicable if input subscript does not have repeated labels. 
-/// -/// \param subscript A subscript to check its format -/// -/// \return true - applicable, false - not applicable -/// -bool is_subscript_applicable(const std::string& subscript) { - auto labels = ov::op::v7::Einsum::extract_labels(subscript); - auto unique_labels = std::unordered_set(labels.begin(), labels.end()); - return unique_labels.size() == labels.size(); -} - /// \brief Compute einsum_path for a given Einsum node meaning that the (pseudo-)optimal /// order of operands contraction in terms of performance and memory consumption /// @@ -595,6 +587,167 @@ void reduce_input(ov::pass::EinsumDecomposition* einsum_decompose_ptr, subgraph_nodes.insert(subgraph_nodes.end(), {axes_const, reduce_sum}); } +ov::Output build_identity(const ov::Output& input_node, + const std::vector& repeated_label_dims, + ov::NodeVector& subgraph_nodes) { + OPENVINO_ASSERT(repeated_label_dims.size() > 1); + const auto input_shape = std::make_shared(input_node); + const auto repeated_label_indices = + ov::op::v0::Constant::create(ov::element::i64, {repeated_label_dims.size()}, repeated_label_dims); + const auto const_0 = ov::op::v0::Constant::create(ov::element::i64, {}, {0}); + const auto const_1 = ov::op::v0::Constant::create(ov::element::i64, {}, {1}); + const auto repeated_dimensions = std::make_shared(input_shape, repeated_label_indices, const_0); + const auto reduced_dimension = std::make_shared(repeated_dimensions, const_0, const_0); + const auto reduced_dimension_min_1 = std::make_shared(reduced_dimension, const_1); + + const auto reduced_size = std::make_shared(repeated_dimensions, const_0, true); + const auto reduced_size_min_1 = std::make_shared(reduced_size, const_1); + const auto step_size = std::make_shared(reduced_size_min_1, reduced_dimension_min_1); + const auto range = std::make_shared(const_0, reduced_dimension, const_1, ov::element::i64); + const auto steps = std::make_shared(range, step_size); + const auto zeros = std::make_shared(const_0, reduced_size); + const auto reduced_dimension_1d = std::make_shared(reduced_dimension, const_0); + const auto ones = std::make_shared(const_1, reduced_dimension_1d); + const auto eye_flattened = std::make_shared(zeros, steps, ones, const_0); + + const auto identity_rank = std::make_shared(input_shape); + const auto ones_of_input_shape_rank = std::make_shared(const_1, identity_rank); + const auto identity_shape = std::make_shared(ones_of_input_shape_rank, + repeated_label_indices, + repeated_dimensions, + const_0); + const auto identity = std::make_shared(eye_flattened, identity_shape, false); + const auto identity_cvt = std::make_shared(identity, input_node.get_element_type()); + subgraph_nodes.insert(subgraph_nodes.end(), + {input_shape, + repeated_label_indices, + const_0, + const_1, + repeated_dimensions, + reduced_dimension, + reduced_dimension_min_1, + reduced_size, + reduced_size_min_1, + step_size, + range, + steps, + zeros, + reduced_dimension_1d, + ones, + eye_flattened, + identity_rank, + ones_of_input_shape_rank, + identity_shape, + identity, + identity_cvt}); + return subgraph_nodes.back(); +} + +ov::Output build_multi_identity(ov::pass::EinsumDecomposition* einsum_decompose_ptr, + const ov::Output& input_node, + const std::vector& repeated_labels, + const LabelDimMap& label_dim_map, + ov::NodeVector& subgraph_nodes) { + OPENVINO_ASSERT(repeated_labels.size() > 0); + + const auto get_identity = [&](size_t idx) { + const auto repeated_label_dims = label_dim_map.find(repeated_labels[idx]); + OPENVINO_ASSERT(repeated_label_dims != 
label_dim_map.end()); + return build_identity(input_node, repeated_label_dims->second, subgraph_nodes); + }; + + // initially set multi-identity with identity for the first repeated label + const auto multi_identity = get_identity(0); + for (size_t label_ind = 1; label_ind < repeated_labels.size(); ++label_ind) { + const auto identity = get_identity(label_ind); + const auto mul = + std::make_shared(multi_identity, identity, ov::op::AutoBroadcastType::NUMPY); + subgraph_nodes.insert(subgraph_nodes.end(), {mul}); + } + + return subgraph_nodes.back(); +} + +/// \brief Helper function to fill in the data needed for diagonal extraction - result shape +/// and subscript, repeated labels, axes to reduce. +/// +void prepare_diagonal_extraction_data(const std::string& input_subscript, + const LabelDimMap& label_dim_map, + std::string& resultant_subscript, + std::vector& repeated_labels, + ov::AxisSet& reduced_axes) { + static const std::string ellipsis = "..."; + const auto labels = ov::op::v7::Einsum::extract_labels(input_subscript); + + for (const auto& label : labels) { + if (resultant_subscript.find(label) != std::string::npos) { + continue; + } + + const auto dims_it = label_dim_map.find(label); + OPENVINO_ASSERT(dims_it != label_dim_map.end()); + + auto dims = dims_it->second; + const auto dims_size = dims.size(); + OPENVINO_ASSERT(dims_size > 0); + + if (label != ellipsis && dims_size > 1) { + // repeated label is found + for (size_t dim_ind = 1; dim_ind < dims_size; ++dim_ind) { + reduced_axes.insert(dims[dim_ind]); + } + // save only the first dimension corresponding to the repeated label + dims = {dims[0]}; + repeated_labels.push_back(label); + } + resultant_subscript += label; + } +} + +void extract_diagonal(ov::pass::EinsumDecomposition* einsum_decompose_ptr, + ov::OutputVector& inputs, + std::vector& input_subscripts, + size_t input_ind, + ov::NodeVector& subgraph_nodes) { + // perform sanity check for arguments + const auto num_inputs = inputs.size(); + OPENVINO_ASSERT(num_inputs == input_subscripts.size(), "Each input must have own subscript."); + OPENVINO_ASSERT(input_ind < num_inputs, "Input index is out of range."); + + const auto& input_node = inputs[input_ind]; + const auto& input_subscript = input_subscripts[input_ind]; + + const auto label_dim_map = compute_label_dim_map(input_node.get_partial_shape().rank(), input_subscript); + std::string resultant_subscript; + std::vector repeated_labels; + ov::AxisSet reduced_axes; + prepare_diagonal_extraction_data(input_subscript, + label_dim_map, + resultant_subscript, + repeated_labels, + reduced_axes); + + if (repeated_labels.size() == 0) { + return; + } + const auto multi_identity = + build_multi_identity(einsum_decompose_ptr, input_node, repeated_labels, label_dim_map, subgraph_nodes); + + // multiply both operands with broadcasting + const auto mul = + std::make_shared(input_node, multi_identity, ov::op::AutoBroadcastType::NUMPY); + subgraph_nodes.insert(subgraph_nodes.end(), {mul}); + + const std::vector reduced_axes_vec{reduced_axes.cbegin(), reduced_axes.cend()}; + const auto axes_const = + ov::op::v0::Constant::create(ov::element::Type_t::i64, ov::Shape{reduced_axes.size()}, reduced_axes_vec); + const auto reduce_sum = std::make_shared(mul->output(0), axes_const, false); + subgraph_nodes.insert(subgraph_nodes.end(), {axes_const, reduce_sum}); + + inputs[input_ind] = reduce_sum->output(0); + input_subscripts[input_ind] = resultant_subscript; +} + /// \brief Contract two inputs of Einsum operation according to equation. 
/// The result of the contraction is appended into input_nodes along with its subscript. /// The input nodes for these two operands are removed from input_nodes along with their input @@ -630,9 +783,8 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, const auto& input_node2 = input_nodes[input_ind2]; // extract diagonals in case repeated labels in the corresponding input subscripts - // TODO - // extract_diagonal(einsum_decompose_ptr, input_nodes, input_subscripts, input_ind1, subgraph_nodes); - // extract_diagonal(einsum_decompose_ptr, input_nodes, input_subscripts, input_ind2, subgraph_nodes); + extract_diagonal(einsum_decompose_ptr, input_nodes, input_subscripts, input_ind1, subgraph_nodes); + extract_diagonal(einsum_decompose_ptr, input_nodes, input_subscripts, input_ind2, subgraph_nodes); // reduce dimensions for input operands if possible reduce_input(einsum_decompose_ptr, input_nodes, input_subscripts, output_subscript, input_ind1, subgraph_nodes); @@ -772,8 +924,6 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, auto matmul_operand1 = input_node1; auto matmul_operand2 = input_node2; - auto broadcasted_operand1 = input_node1; - auto broadcasted_operand2 = input_node2; size_t common_dims_begin, common_dims_end, reduced_dims_begin, reduced_dims_end, separate1_dims_begin, separate1_dims_end; @@ -827,21 +977,21 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, reduced_sub_shape_prod = broadcast_merge_shapes(reduced_sub_shape_prod, reduced_sub_shape_prod2, subgraph_nodes); reduced_sub_shape = broadcast_merge_shapes(reduced_sub_shape, reduced_sub_shape2, subgraph_nodes); separate1_sub_shape = compute_sub_shape(data_shape1, separate1_dims_begin, separate1_dims_end, subgraph_nodes); - broadcasted_operand1 = broadcast_input(input_node1, - common_sub_shape, - separate1_sub_shape, - reduced_sub_shape, - is_separate_first1, - subgraph_nodes); + matmul_operand1 = broadcast_input(input_node1, + common_sub_shape, + separate1_sub_shape, + reduced_sub_shape, + is_separate_first1, + subgraph_nodes); separate2_sub_shape = compute_sub_shape(data_shape2, separate2_dims_begin, separate2_dims_end, subgraph_nodes); - broadcasted_operand2 = broadcast_input(input_node2, - common_sub_shape, - separate2_sub_shape, - reduced_sub_shape, - is_separate_first2, - subgraph_nodes); + matmul_operand2 = broadcast_input(input_node2, + common_sub_shape, + separate2_sub_shape, + reduced_sub_shape, + is_separate_first2, + subgraph_nodes); if (no_reshape_for_matmul1 == false || no_reshape_after_matmul == false) { - matmul_operand1 = reshape_input_for_matmul(broadcasted_operand1, + matmul_operand1 = reshape_input_for_matmul(matmul_operand1, common_sub_shape, separate1_sub_shape, reduced_sub_shape_prod, @@ -850,7 +1000,7 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, } if (no_reshape_for_matmul2 == false || no_reshape_after_matmul == false) { - matmul_operand2 = reshape_input_for_matmul(broadcasted_operand2, + matmul_operand2 = reshape_input_for_matmul(matmul_operand2, common_sub_shape, separate2_sub_shape, reduced_sub_shape_prod, @@ -926,13 +1076,6 @@ ov::pass::EinsumDecomposition::EinsumDecomposition() { std::string output_subscript; ov::op::v7::Einsum::parse_equation(equation, input_subscripts, output_subscript); - // check that the transformation is applicable - if (std::any_of(input_subscripts.cbegin(), input_subscripts.cend(), [](const std::string& subscript) { - return 
is_subscript_applicable(subscript) == false; - })) { - return false; - } - // create a list of input nodes with preserving their order // and a vector of sub-graph nodes for copy_runtime_info ov::OutputVector input_nodes = einsum_node->input_values(); @@ -962,8 +1105,7 @@ ov::pass::EinsumDecomposition::EinsumDecomposition() { OPENVINO_ASSERT(input_nodes.size() == 1); // extract diagonal for the single operand - // TODO - // extract_diagonal(this, input_nodes, input_subscripts, 0, subgraph_nodes); + extract_diagonal(this, input_nodes, input_subscripts, 0, subgraph_nodes); // reduce dimensions for the remained input node reduce_input(this, input_nodes, input_subscripts, output_subscript, 0, subgraph_nodes); // transpose dimensions to layout required by the output subscript From fa041ca75dfa456ecb71e72115ad257beb3896c7 Mon Sep 17 00:00:00 2001 From: MATEUSZ MIKOLAJCZYK Date: Wed, 8 Jan 2025 11:11:18 +0000 Subject: [PATCH 05/47] Remove xfail for onnx einsum test Signed-off-by: MATEUSZ MIKOLAJCZYK --- src/frontends/onnx/tests/__init__.py | 1 - src/frontends/onnx/tests/tests_python/test_backend.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/src/frontends/onnx/tests/__init__.py b/src/frontends/onnx/tests/__init__.py index fdf1295dfd1dbe..bd5477dbcf8a13 100644 --- a/src/frontends/onnx/tests/__init__.py +++ b/src/frontends/onnx/tests/__init__.py @@ -120,7 +120,6 @@ def xfail_test(reason="Mark the test as expected to fail", strict=True): xfail_issue_49754 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v1::TopKIE") xfail_issue_52463 = xfail_test(reason="test_operator_add_size1_singleton_broadcast_cpu - " "Not equal to tolerance") -xfail_issue_58033 = xfail_test(reason="Einsum operation misses support for complex ellipsis equations") xfail_issue_58676 = xfail_test(reason="AssertionError: Not equal to tolerance rtol=0.001, atol=1e-07") skip_issue_58676 = pytest.mark.skip(reason="AssertionError: Not equal to tolerance rtol=0.001, atol=1e-07") xfail_issue_onnx_models_140 = xfail_test(reason="https://github.com/onnx/models/issues/140") diff --git a/src/frontends/onnx/tests/tests_python/test_backend.py b/src/frontends/onnx/tests/tests_python/test_backend.py index 39b9788d720af3..487454675ac50e 100644 --- a/src/frontends/onnx/tests/tests_python/test_backend.py +++ b/src/frontends/onnx/tests/tests_python/test_backend.py @@ -32,7 +32,6 @@ xfail_issue_73538, xfail_issue_48052, xfail_issue_52463, - xfail_issue_58033, xfail_issue_63033, xfail_issue_63036, xfail_issue_63043, @@ -292,7 +291,6 @@ def expect_fail(test_case_path, xfail): # type: (str) -> None "OnnxBackendNodeModelTest.test_sequence_insert_at_back_cpu", "OnnxBackendNodeModelTest.test_sequence_insert_at_front_cpu", ), - (xfail_issue_58033, "OnnxBackendNodeModelTest.test_einsum_batch_diagonal_cpu"), ( xfail_issue_63033, "OnnxBackendNodeModelTest.test_batchnorm_epsilon_training_mode_cpu", From 6796536d7dadca42999a6018366f56fbf71a0ec7 Mon Sep 17 00:00:00 2001 From: MATEUSZ MIKOLAJCZYK Date: Wed, 8 Jan 2025 13:27:41 +0000 Subject: [PATCH 06/47] Remove Einsum xfail for torch HF tests Signed-off-by: MATEUSZ MIKOLAJCZYK --- tests/model_hub_tests/pytorch/hf_transformers_models | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/model_hub_tests/pytorch/hf_transformers_models b/tests/model_hub_tests/pytorch/hf_transformers_models index 3d05a430a16671..c861c0cb6c64ee 100644 --- a/tests/model_hub_tests/pytorch/hf_transformers_models +++ b/tests/model_hub_tests/pytorch/hf_transformers_models @@ -88,7 +88,7 @@ 
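Note: the xfail removals in this and the previous patch re-enable Einsum coverage — ONNX's test_einsum_batch_diagonal above, and several previously einsum-blocked Hugging Face models below. For reference, the batch-diagonal pattern that ONNX test exercises is "...ii->...i"; a small sketch of that computation with assumed shapes and row-major layout, not the actual test harness:

#include <cstddef>
#include <vector>

// "...ii->...i": take the diagonal of each d x d slice, e.g. {3, 5, 5} -> {3, 5}.
std::vector<float> batch_diagonal(const std::vector<float>& x, std::size_t b, std::size_t d) {
    std::vector<float> out(b * d);
    for (std::size_t n = 0; n < b; ++n)
        for (std::size_t i = 0; i < d; ++i)
            out[n * d + i] = x[(n * d + i) * d + i];  // element (n, i, i)
    return out;
}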
hifigan,microsoft/speecht5_hifigan,xfail,Load error: The size of tensor a (100) hubert,facebook/hubert-large-ls960-ft hybridbert,gokuls/bert_12_layer_model_v1 ibert,DunnBC22/ibert-roberta-base-Abusive_Or_Threatening_Speech -idefics,HuggingFaceM4/tiny-random-idefics,xfail,aten::einsum Different input dimensions indicated by the same labels for Einsum must be compatible +idefics,HuggingFaceM4/tiny-random-idefics imagegpt,openai/imagegpt-small informer,huggingface/informer-tourism-monthly,xfail,Load error: mat1 and mat2 shapes cannot be multiplied instructblip,Salesforce/instructblip-vicuna-7b @@ -106,7 +106,7 @@ levit,facebook/levit-128S,xfail,Trace error: Cannot insert a Tensor that require lilt,nielsr/lilt-xlm-roberta-base llama_with_landmark,Leooyii/Landmark_512_Slimpajama_1B longformer,allenai/longformer-base-4096 -longt5,pszemraj/long-t5-tglobal-base-16384-book-summary,xfail,(CVS-148676) Compile error: unsupported Einsum +longt5,pszemraj/long-t5-tglobal-base-16384-book-summary luke,oshizo/sbert-jsnli-luke-japanese-base-lite lxmert,unc-nlp/lxmert-base-uncased m2m_100,facebook/nllb-200-distilled-600M @@ -119,7 +119,7 @@ mbart,facebook/mbart-large-50-many-to-many-mmt mctct,speechbrain/m-ctc-t-large mega,Bingsu/mega-150m-arch,xfail,Trace error: Cannot insert a Tensor that requires grad as a constant megatron-bert,UFNLP/gatortron-base -mgp-str,alibaba-damo/mgp-str-base,xfail,(CVS-148676) Compile error: unsupported Einsum +mgp-str,alibaba-damo/mgp-str-base mobilebert,google/mobilebert-uncased mobilenet_v1,google/mobilenet_v1_0.75_192 mobilenet_v2,google/mobilenet_v2_1.0_224 From be8400c7f95f8a8aa3fc216882ccd0a0e60d38d2 Mon Sep 17 00:00:00 2001 From: MATEUSZ MIKOLAJCZYK Date: Thu, 16 Jan 2025 12:42:26 +0000 Subject: [PATCH 07/47] Update transpose reshape elimination for MatMul to handle broadcast from Einsum Signed-off-by: MATEUSZ MIKOLAJCZYK --- ...anspose_reshape_elimination_for_matmul.cpp | 46 +++++++++++++++++-- ...anspose_reshape_elimination_for_matmul.cpp | 14 +++++- 2 files changed, 54 insertions(+), 6 deletions(-) diff --git a/src/common/transformations/src/transformations/common_optimizations/transpose_reshape_elimination_for_matmul.cpp b/src/common/transformations/src/transformations/common_optimizations/transpose_reshape_elimination_for_matmul.cpp index d3eff542d6b7af..caac91de147ab6 100644 --- a/src/common/transformations/src/transformations/common_optimizations/transpose_reshape_elimination_for_matmul.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/transpose_reshape_elimination_for_matmul.cpp @@ -9,10 +9,12 @@ #include "itt.hpp" #include "openvino/core/rt_info.hpp" +#include "openvino/op/broadcast.hpp" #include "openvino/op/constant.hpp" #include "openvino/op/matmul.hpp" #include "openvino/op/reshape.hpp" #include "openvino/op/transpose.hpp" +#include "openvino/pass/pattern/op/or.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" namespace { @@ -124,9 +126,16 @@ ov::pass::TransposeReshapeEliminationForMatmul::TransposeReshapeEliminationForMa auto transpose_before_pattern = ov::pass::pattern::wrap_type({input_2_pattern, const_transpose_before_pattern}); + auto const_optional_broadcast_before_pattern = ov::pass::pattern::wrap_type(); + auto optional_broadcast_before_pattern = ov::pass::pattern::wrap_type( + {transpose_before_pattern, const_optional_broadcast_before_pattern}); + + auto transpose_or_transpose_broadcast = std::make_shared( + OutputVector{transpose_before_pattern, optional_broadcast_before_pattern}); + auto const_reshape_before_pattern 
= ov::pass::pattern::wrap_type(); - auto reshape_before_pattern = - ov::pass::pattern::wrap_type({transpose_before_pattern, const_reshape_before_pattern}); + auto reshape_before_pattern = ov::pass::pattern::wrap_type( + {transpose_or_transpose_broadcast, const_reshape_before_pattern}); auto matmul_pattern = ov::pass::pattern::wrap_type({input_1_pattern, reshape_before_pattern}); @@ -181,8 +190,37 @@ ov::pass::TransposeReshapeEliminationForMatmul::TransposeReshapeEliminationForMa // transposes if (!check_transposes(transpose_before_order, transpose_after_order, transposed_b)) return false; - - const auto new_matmul = std::make_shared(input_1, input_2, transposed_a, false); + auto matmul_2_input = input_2; + // for einsum decomposition, check if broadcast exist and if so, reorder target shape based on transpose + if (pattern_value_map.count(optional_broadcast_before_pattern)) { + auto broadcast_before = ov::as_type_ptr( + pattern_value_map.at(optional_broadcast_before_pattern).get_node_shared_ptr()); + if (!broadcast_before) { + return false; + } + auto broadcast_before_constant = + ov::as_type_ptr(broadcast_before->get_input_node_shared_ptr(1)); + if (!broadcast_before_constant) { + return false; + } + auto broadcast_shape_after_transpose = broadcast_before_constant->cast_vector(); + if (broadcast_shape_after_transpose.size() != transpose_before_order.size()) { + return false; + } + std::vector broadcast_shape_no_transpose; + broadcast_shape_no_transpose.reserve(broadcast_shape_after_transpose.size()); + for (auto idx : transpose_before_order) { + broadcast_shape_no_transpose.push_back(broadcast_shape_after_transpose[idx]); + } + auto broadcast_shape_no_transpose_constant = + ov::op::v0::Constant::create(element::i64, + broadcast_before_constant->get_shape(), + broadcast_shape_no_transpose); + matmul_2_input = broadcast_before->clone_with_new_inputs({input_2, broadcast_shape_no_transpose_constant}); + copy_runtime_info(broadcast_before, matmul_2_input.get_node_shared_ptr()); + } + + const auto new_matmul = std::make_shared(input_1, matmul_2_input, transposed_a, false); new_matmul->set_friendly_name(transpose_after->get_friendly_name()); copy_runtime_info({transpose_before, reshape_before, matmul, reshape_after, transpose_after}, new_matmul); replace_node(transpose_after, new_matmul); diff --git a/src/common/transformations/tests/common_optimizations/transpose_reshape_elimination_for_matmul.cpp b/src/common/transformations/tests/common_optimizations/transpose_reshape_elimination_for_matmul.cpp index ea57598a16c653..1f8d376f86800d 100644 --- a/src/common/transformations/tests/common_optimizations/transpose_reshape_elimination_for_matmul.cpp +++ b/src/common/transformations/tests/common_optimizations/transpose_reshape_elimination_for_matmul.cpp @@ -138,11 +138,21 @@ TEST_F(TransformationTestsF, TransposeReshapeEliminationForMatMul_Einsum) { { auto data_1 = std::make_shared(element::f32, data_shape_1); auto data_2 = std::make_shared(element::f32, data_shape_2); + auto broadcast_shape_constant_1 = + std::make_shared(element::i64, Shape{data_shape_1.size()}, data_shape_1); + auto broadcast_shape_constant_2 = + std::make_shared(element::i64, Shape{data_shape_2.size()}, data_shape_2); + auto broadcast_1 = std::make_shared(data_1, + broadcast_shape_constant_1, + ov::op::BroadcastType::BIDIRECTIONAL); + auto broadcast_2 = std::make_shared(data_2, + broadcast_shape_constant_2, + ov::op::BroadcastType::BIDIRECTIONAL); // for some cases Reshape may be first input for Matmul auto shape_constant = 
std::make_shared(element::i64, Shape{data_shape_1.size()}, data_shape_1); - auto reshape = std::make_shared(data_1, shape_constant, false); - auto matmul = std::make_shared(reshape, data_2, false, false); + auto reshape = std::make_shared(broadcast_1, shape_constant, false); + auto matmul = std::make_shared(reshape, broadcast_2, false, false); model_ref = std::make_shared(NodeVector{matmul}, ParameterVector{data_1, data_2}); } } From 81b5d3907e601c42a8b5b629b95daf9cba7128a8 Mon Sep 17 00:00:00 2001 From: MATEUSZ MIKOLAJCZYK Date: Fri, 17 Jan 2025 18:19:41 +0000 Subject: [PATCH 08/47] Initial Einsum update to handle ellipsis label without dimensions Signed-off-by: MATEUSZ MIKOLAJCZYK --- .../op_conversions/einsum_decomposition.cpp | 94 ++++++++++++++----- src/core/reference/src/op/einsum.cpp | 45 ++++++++- .../include/einsum_shape_inference.hpp | 4 + 3 files changed, 112 insertions(+), 31 deletions(-) diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp index 3a8f9b4c0c5ffb..0aea695fab1fb1 100644 --- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp @@ -896,32 +896,6 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, is_separate_first2); transpose_input(input_nodes, input_subscripts, int_subscript2, input_ind2, subgraph_nodes); - // step 2. reshape both operands so that separate labels and reduced labels are represented - // with just one dimension this is needed by MatMul operation requirement to operands - // format. For example, the shape must be in a format [B1, ..., Bm, X1, Y] or [B1, ..., Bm, - // Y, X2], where B1, ..., Bm are common dimensions, X1 and X2 are collapsed dimensions - // for separate labels and Y is collapsed dimension for reduced labels - // this step is not needed for the operand if it satisfies to one of the requirements: - // 1. there is just one separate dimension and just one reduced dimension - // 2. there is no separate dimension, no common dimensions, and just one reduced dimension - bool no_reshape_for_matmul1 = - (reduced_labels_inds1.size() == 1 && separate_labels_inds1.size() == 1) || - (reduced_labels_inds1.size() == 1 && common_labels_inds1.size() == 0 && separate_labels_inds1.size() == 0); - bool no_reshape_for_matmul2 = - (reduced_labels_inds2.size() == 1 && separate_labels_inds2.size() == 1) || - (reduced_labels_inds2.size() == 1 && common_labels_inds2.size() == 0 && separate_labels_inds2.size() == 0); - // reshape back after MatMul is not needed if one of two requrements satisfies for both operands: - // 1. there is just one separate dimension - // 2. there is no separate dimension and no common dimensions present. 
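Note: as a concrete instance of the MatMul operand format described in the comment block being relocated here, a sketch that collapses the (common ++ separate ++ reduced) sub-shapes the same way reshape_input_for_matmul does; the helper name and the example values are illustrative only:

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Collapse the separate and reduced dims each into a single dimension:
// common {2, 3}, separate {4, 5}, reduced {6, 7} with separate first gives
// {2, 3, 20, 42}; a matching rhs {2, 3, 42, 8} then feeds MatMul, which
// yields {2, 3, 20, 8}.
std::vector<int64_t> matmul_operand_shape(std::vector<int64_t> common,
                                          const std::vector<int64_t>& separate,
                                          const std::vector<int64_t>& reduced,
                                          bool is_separate_first) {
    const auto prod = [](const std::vector<int64_t>& dims) {
        return std::accumulate(dims.begin(), dims.end(), int64_t{1}, std::multiplies<int64_t>());
    };
    common.push_back(is_separate_first ? prod(separate) : prod(reduced));
    common.push_back(is_separate_first ? prod(reduced) : prod(separate));
    return common;
}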
- // If there is no separate dimension and common dimensions present, reshape is needed - // because auxiliary separate dimension has been added by Unsqueeze operation - // in the purpose for MatMul - bool no_reshape_back1 = - (separate_labels_inds1.size() == 1) || (common_labels_inds1.size() == 0 && separate_labels_inds1.size() == 0); - bool no_reshape_back2 = - (separate_labels_inds2.size() == 1) || (common_labels_inds2.size() == 0 && separate_labels_inds2.size() == 0); - bool no_reshape_after_matmul = no_reshape_back1 && no_reshape_back2; - auto matmul_operand1 = input_node1; auto matmul_operand2 = input_node2; @@ -990,6 +964,34 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, reduced_sub_shape, is_separate_first2, subgraph_nodes); + + // step 2. reshape both operands so that separate labels and reduced labels are represented + // with just one dimension this is needed by MatMul operation requirement to operands + // format. For example, the shape must be in a format [B1, ..., Bm, X1, Y] or [B1, ..., Bm, + // Y, X2], where B1, ..., Bm are common dimensions, X1 and X2 are collapsed dimensions + // for separate labels and Y is collapsed dimension for reduced labels + // this step is not needed for the operand if it satisfies to one of the requirements: + // 1. there is just one separate dimension and just one reduced dimension + // 2. there is no separate dimension, no common dimensions, and just one reduced dimension + const auto common_labels1_size = common_dims_end - common_dims_begin; + const auto common_labels2_size = common_dims_end2 - common_dims_begin2; + const auto reduced_labels1_size = reduced_dims_end - reduced_dims_begin; + const auto reduced_labels2_size = reduced_dims_end2 - reduced_dims_begin2; + const auto separate_labels1_size = separate1_dims_end - separate1_dims_begin; + const auto separate_labels2_size = separate2_dims_end - separate2_dims_begin; + bool no_reshape_for_matmul1 = (reduced_labels1_size == 1 && separate_labels1_size == 1) || + (reduced_labels1_size == 1 && common_labels1_size == 0 && separate_labels1_size == 0); + bool no_reshape_for_matmul2 = (reduced_labels2_size == 1 && separate_labels2_size == 1) || + (reduced_labels2_size == 1 && common_labels2_size == 0 && separate_labels2_size == 0); + // reshape back after MatMul is not needed if one of two requirements satisfies for both operands: + // 1. there is just one separate dimension + // 2. there is no separate dimension and no common dimensions present. 
+ // If there is no separate dimension and common dimensions present, reshape is needed + // because auxiliary separate dimension has been added by Unsqueeze operation + // in the purpose for MatMul + bool no_reshape_back1 = (separate_labels1_size == 1) || (common_labels1_size == 0 && separate_labels1_size == 0); + bool no_reshape_back2 = (separate_labels2_size == 1) || (common_labels2_size == 0 && separate_labels2_size == 0); + bool no_reshape_after_matmul = no_reshape_back1 && no_reshape_back2; if (no_reshape_for_matmul1 == false || no_reshape_after_matmul == false) { matmul_operand1 = reshape_input_for_matmul(matmul_operand1, common_sub_shape, @@ -1091,6 +1093,46 @@ ov::pass::EinsumDecomposition::EinsumDecomposition() { // in more optimal order auto einsum_path = compute_einsum_path(einsum_node); + // fix inputs where ellipsis does not contain any dimensions + std::vector ellipsis_inputs(input_nodes.size(), false); + std::vector no_ellipsis_or_empty_inputs(input_nodes.size(), false); + static const std::string ellipsis = "..."; + for (size_t inp_iter = 0; inp_iter < input_nodes.size(); inp_iter++) { + const auto& labels = ov::op::v7::Einsum::extract_labels(input_subscripts[inp_iter]); + ellipsis_inputs[inp_iter] = (std::find(labels.begin(), labels.end(), "...") != labels.end()); + if (!ellipsis_inputs[inp_iter] || + (input_nodes[inp_iter].get_partial_shape().rank() == (labels.size() - 1))) { + no_ellipsis_or_empty_inputs[inp_iter] = true; + } + } + if (std::none_of(ellipsis_inputs.begin(), ellipsis_inputs.end(), [](bool inp) { + return inp; + })) { + if (output_subscript.find("...") != std::string::npos) { + output_subscript.erase(output_subscript.find("..."), 3); + } + } else if (std::all_of(no_ellipsis_or_empty_inputs.begin(), no_ellipsis_or_empty_inputs.end(), [](bool inp) { + return inp; + })) { + for (size_t inp_iter = 0; inp_iter < input_nodes.size(); inp_iter++) { + if (input_subscripts[inp_iter].find("...") != std::string::npos) { + input_subscripts[inp_iter].erase(input_subscripts[inp_iter].find("..."), 3); + } + } + if (output_subscript.find("...") != std::string::npos) { + output_subscript.erase(output_subscript.find("..."), 3); + } + } else { + for (size_t inp_iter = 0; inp_iter < input_nodes.size(); inp_iter++) { + if (ellipsis_inputs[inp_iter] && no_ellipsis_or_empty_inputs[inp_iter]) { + auto labels = ov::op::v7::Einsum::extract_labels(input_subscripts[inp_iter]); + auto ellipsis_idx_iter = std::find(labels.begin(), labels.end(), "..."); + std::vector ellipsis_idx{std::distance(labels.begin(), ellipsis_idx_iter)}; + input_nodes[inp_iter] = unsqueeze_input(input_nodes[inp_iter], ellipsis_idx, subgraph_nodes); + } + } + } + // contract inputs by Einsum until just one is remained for (auto const& inds_pair : einsum_path) { contract_two_inputs(this, diff --git a/src/core/reference/src/op/einsum.cpp b/src/core/reference/src/op/einsum.cpp index d16e500b40e2fe..4457b628f670be 100644 --- a/src/core/reference/src/op/einsum.cpp +++ b/src/core/reference/src/op/einsum.cpp @@ -853,9 +853,7 @@ void contract_two_inputs(ov::TensorVector& inputs, PartialShape common_sub_shape1 = compute_sub_shape(input_shape1, common_dims_begin, common_dims_end); PartialShape common_sub_shape2 = compute_sub_shape(input_shape2, common_dims_begin2, common_dims_end2); - PartialShape reduced_sub_shape_prod = compute_sub_shape(input_shape1, reduced_dims_begin, reduced_dims_end, true); PartialShape reduced_sub_shape = compute_sub_shape(input_shape1, reduced_dims_begin, reduced_dims_end); - Shape 
reduced_sub_shape_prod2 = compute_sub_shape(input_shape2, reduced_dims_begin2, reduced_dims_end2, true); Shape reduced_sub_shape2 = compute_sub_shape(input_shape2, reduced_dims_begin2, reduced_dims_end2); Shape separate1_sub_shape = compute_sub_shape(input_shape1, separate1_dims_begin, separate1_dims_end); Shape separate2_sub_shape = compute_sub_shape(input_shape2, separate2_dims_begin, separate2_dims_end); @@ -865,7 +863,7 @@ void contract_two_inputs(ov::TensorVector& inputs, // reference::broadcast() PartialShape::broadcast_merge_into(common_sub_shape1, common_sub_shape2, op::AutoBroadcastType::NUMPY); PartialShape::broadcast_merge_into(reduced_sub_shape, reduced_sub_shape2, op::AutoBroadcastType::NUMPY); - PartialShape::broadcast_merge_into(reduced_sub_shape_prod, reduced_sub_shape_prod2, op::AutoBroadcastType::NUMPY); + Shape reduced_sub_shape_prod = {shape_size(reduced_sub_shape.get_shape())}; Shape common_sub_shape = common_sub_shape1.get_shape(); broadcast_input(inputs, input_ind1, @@ -883,13 +881,13 @@ void contract_two_inputs(ov::TensorVector& inputs, ov::Tensor matmul_operand1 = reshape_input_for_matmul(input1, common_sub_shape, separate1_sub_shape, - reduced_sub_shape_prod.get_shape(), + reduced_sub_shape_prod, is_separate_first1); ov::Tensor matmul_operand2 = reshape_input_for_matmul(input2, common_sub_shape, separate2_sub_shape, - reduced_sub_shape_prod.get_shape(), + reduced_sub_shape_prod, is_separate_first2); // step 3. apply MatMul operation for formatted inputs @@ -941,6 +939,43 @@ void einsum_impl(const ov::TensorVector& inputs, ov::TensorVector& outputs, cons auto einsum_path = compute_einsum_path(num_inputs); ov::TensorVector int_inputs = inputs; + std::vector ellipsis_inputs(inputs.size(), false); + std::vector no_ellipsis_or_empty_inputs(inputs.size(), false); + static const std::string ellipsis = "..."; + for (size_t inp_iter = 0; inp_iter < inputs.size(); inp_iter++) { + const auto& labels = ov::op::v7::Einsum::extract_labels(input_subscripts[inp_iter]); + ellipsis_inputs[inp_iter] = (std::find(labels.begin(), labels.end(), "...") != labels.end()); + if (!ellipsis_inputs[inp_iter] || (inputs[inp_iter].get_shape().size() == (labels.size() - 1))) { + no_ellipsis_or_empty_inputs[inp_iter] = true; + } + } + if (std::none_of(ellipsis_inputs.begin(), ellipsis_inputs.end(), [](bool inp) { + return inp; + })) { + if (output_subscript.find("...") != std::string::npos) { + output_subscript.erase(output_subscript.find("..."), 3); + } + } else if (std::all_of(no_ellipsis_or_empty_inputs.begin(), no_ellipsis_or_empty_inputs.end(), [](bool inp) { + return inp; + })) { + for (size_t inp_iter = 0; inp_iter < inputs.size(); inp_iter++) { + if (input_subscripts[inp_iter].find("...") != std::string::npos) { + input_subscripts[inp_iter].erase(input_subscripts[inp_iter].find("..."), 3); + } + } + if (output_subscript.find("...") != std::string::npos) { + output_subscript.erase(output_subscript.find("..."), 3); + } + } else { + for (size_t inp_iter = 0; inp_iter < inputs.size(); inp_iter++) { + if (ellipsis_inputs[inp_iter] && no_ellipsis_or_empty_inputs[inp_iter]) { + auto labels = ov::op::v7::Einsum::extract_labels(input_subscripts[inp_iter]); + auto ellipsis_idx_iter = std::find(labels.begin(), labels.end(), "..."); + std::vector ellipsis_idx{std::distance(labels.begin(), ellipsis_idx_iter)}; + int_inputs[inp_iter] = unsqueeze_input(inputs[inp_iter], ellipsis_idx); + } + } + } // contract inputs by Einsum until just one is remained for (auto const& inds_pair : einsum_path) { diff 
--git a/src/core/shape_inference/include/einsum_shape_inference.hpp b/src/core/shape_inference/include/einsum_shape_inference.hpp index 2a7cd60369261e..1ee471117d6872 100644 --- a/src/core/shape_inference/include/einsum_shape_inference.hpp +++ b/src/core/shape_inference/include/einsum_shape_inference.hpp @@ -101,6 +101,10 @@ std::vector shape_infer(const Einsum* op, const std::vector& input_s auto& output_shape = output_shapes[0]; for (auto const& output_label : output_labels) { + if (output_label == "..." && label_to_shape.find(output_label) == label_to_shape.end()) { + // Output labels may contain ellipsis that does not cover any dimensions. + continue; + } NODE_VALIDATION_CHECK(op, label_to_shape.find(output_label) != label_to_shape.end(), "Label in output subscript of Einsum equation must enter at least " From 28b579ac85d0fc0a8ce23e6fb1c37e1da09961a9 Mon Sep 17 00:00:00 2001 From: MATEUSZ MIKOLAJCZYK Date: Mon, 20 Jan 2025 15:35:11 +0000 Subject: [PATCH 09/47] Update reduce_input in einsum common decomposition Signed-off-by: MATEUSZ MIKOLAJCZYK --- .../op_conversions/einsum_decomposition.cpp | 32 ++++++++++++------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp index 0aea695fab1fb1..8fde064b2214af 100644 --- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp @@ -546,38 +546,46 @@ void reduce_input(ov::pass::EinsumDecomposition* einsum_decompose_ptr, size_t input_ind, ov::NodeVector& subgraph_nodes) { // perform sanity check for arguments - auto num_inputs = input_nodes.size(); + const auto num_inputs = input_nodes.size(); OPENVINO_ASSERT(num_inputs == input_subscripts.size(), "Each input must have own subscript."); OPENVINO_ASSERT(input_ind < num_inputs, "Input index is out of range."); - std::vector reduced_axes; - auto labels = ov::op::v7::Einsum::extract_labels(input_subscripts[input_ind]); + const auto& input_node = input_nodes[input_ind]; + const auto& input_subscript = input_subscripts[input_ind]; + + // compute output shape and axes to reduce + std::set reduced_axes; + const auto labels = ov::op::v7::Einsum::extract_labels(input_subscripts[input_ind]); + auto label_dim_map = compute_label_dim_map(input_node.get_partial_shape().rank(), input_subscript); std::string new_input_subscript = ""; - for (size_t dim_ind = 0; dim_ind < labels.size(); ++dim_ind) { - const auto& label = labels[dim_ind]; + for (const auto& label : labels) { // check if the current label is met in the other input subscripts // or the output subscript - bool is_dim_reduced = is_dimension_reduced(input_subscripts, output_subscript, label, {input_ind}); + const bool is_dim_reduced = is_dimension_reduced(input_subscripts, output_subscript, label, {input_ind}); + + OPENVINO_ASSERT(label_dim_map.find(label) != label_dim_map.end()); + const auto& label_dims = label_dim_map[label]; // if label is not met, dimension corresponding to the label is to reduce if (is_dim_reduced) { - reduced_axes.push_back(dim_ind); + reduced_axes.insert(label_dims.begin(), label_dims.end()); } else { new_input_subscript += label; } } - if (reduced_axes.size() == 0) { + if (reduced_axes.empty()) { // there is no axis to reduce return; } // reduce by summed up elements along dimension for which label is met just once - 
const auto& input_node = input_nodes[input_ind]; - auto axes_const = - ov::op::v0::Constant::create(ov::element::Type_t::i64, ov::Shape{reduced_axes.size()}, reduced_axes); - auto reduce_sum = einsum_decompose_ptr->register_new_node(input_node, axes_const, false); + const std::vector reduced_axes_vec{reduced_axes.cbegin(), reduced_axes.cend()}; + const auto axes_const = + ov::op::v0::Constant::create(ov::element::Type_t::i64, ov::Shape{reduced_axes.size()}, reduced_axes_vec); + const auto reduce_sum = + einsum_decompose_ptr->register_new_node(input_node, axes_const, false); // update a vector of inputs and input subscripts input_nodes[input_ind] = reduce_sum->output(0); From 33acf2ecedfe41ac75d09575f78434e2b0c3d34f Mon Sep 17 00:00:00 2001 From: MATEUSZ MIKOLAJCZYK Date: Mon, 20 Jan 2025 16:26:50 +0000 Subject: [PATCH 10/47] Fix broadcasting of reduced part for reshape Signed-off-by: MATEUSZ MIKOLAJCZYK --- .../op_conversions/einsum_decomposition.cpp | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp index 8fde064b2214af..f21cfcf74aafc5 100644 --- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp @@ -412,7 +412,7 @@ ov::Output broadcast_input(const ov::Output& input_node, ov::Output reshape_input_for_matmul(const ov::Output& input_node, const ov::OutputVector& common_sub_shape, const ov::OutputVector& separate_sub_shape, - const ov::OutputVector& reduced_sub_shape_prod, + const ov::OutputVector& reduced_sub_shape, bool is_separate_first, ov::NodeVector& subgraph_nodes) { ov::OutputVector new_shape_parts; @@ -436,9 +436,17 @@ ov::Output reshape_input_for_matmul(const ov::Output& input_ separate_parts.push_back(separate_shape_prod->output(0)); subgraph_nodes.insert(subgraph_nodes.end(), {reduce_axis_const, separate_shape_prod}); } + ov::OutputVector reduced_sub_shape_prod; + auto const_0 = ov::op::v0::Constant::create(ov::element::i32, {1}, {0}); + for (auto sub_shape : reduced_sub_shape) { + auto product = std::make_shared(sub_shape, const_0, true); + subgraph_nodes.insert(subgraph_nodes.end(), {const_0, product}); + reduced_sub_shape_prod.push_back(product->output(0)); + } // form a new shape for input so that collapsed dimensions corresponding // to the common, separate and reduced dimensions are placed in the correct order + if (is_separate_first) { new_shape_parts.insert(new_shape_parts.end(), separate_parts.begin(), separate_parts.end()); new_shape_parts.insert(new_shape_parts.end(), reduced_sub_shape_prod.begin(), reduced_sub_shape_prod.end()); @@ -454,7 +462,7 @@ ov::Output reshape_input_for_matmul(const ov::Output& input_ auto new_shape_op = std::make_shared(new_shape_parts, 0); - // if new shape is possible to compute on the shape infer stage, insert Constant node immediatelly + // if new shape is possible to compute on the shape infer stage, insert Constant node immediately // in order to prevent repeated computing during constant-folding pass std::shared_ptr reshaped_input_op; if (auto new_shape_const = ov::util::get_constant_from_source(new_shape_op)) { @@ -947,17 +955,13 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, auto common_sub_shape2 = compute_sub_shape(data_shape2, common_dims_begin2, 
common_dims_end2, subgraph_nodes); OPENVINO_ASSERT(common_sub_shape.size() == common_sub_shape2.size()); common_sub_shape = broadcast_merge_shapes(common_sub_shape, common_sub_shape2, subgraph_nodes); - auto reduced_sub_shape_prod = - compute_sub_shape(data_shape1, reduced_dims_begin, reduced_dims_end, subgraph_nodes, true); - auto reduced_sub_shape_prod2 = - compute_sub_shape(data_shape2, reduced_dims_begin2, reduced_dims_end2, subgraph_nodes, true); auto reduced_sub_shape = compute_sub_shape(data_shape1, reduced_dims_begin, reduced_dims_end, subgraph_nodes, false); auto reduced_sub_shape2 = compute_sub_shape(data_shape2, reduced_dims_begin2, reduced_dims_end2, subgraph_nodes, false); - reduced_sub_shape_prod = broadcast_merge_shapes(reduced_sub_shape_prod, reduced_sub_shape_prod2, subgraph_nodes); reduced_sub_shape = broadcast_merge_shapes(reduced_sub_shape, reduced_sub_shape2, subgraph_nodes); + separate1_sub_shape = compute_sub_shape(data_shape1, separate1_dims_begin, separate1_dims_end, subgraph_nodes); matmul_operand1 = broadcast_input(input_node1, common_sub_shape, @@ -1004,7 +1008,7 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, matmul_operand1 = reshape_input_for_matmul(matmul_operand1, common_sub_shape, separate1_sub_shape, - reduced_sub_shape_prod, + reduced_sub_shape, is_separate_first1, subgraph_nodes); } @@ -1013,7 +1017,7 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, matmul_operand2 = reshape_input_for_matmul(matmul_operand2, common_sub_shape, separate2_sub_shape, - reduced_sub_shape_prod, + reduced_sub_shape, is_separate_first2, subgraph_nodes); } From 29b1072bce9df6f14edeb68e5577e40b7f51f636 Mon Sep 17 00:00:00 2001 From: MATEUSZ MIKOLAJCZYK Date: Tue, 21 Jan 2025 16:16:33 +0000 Subject: [PATCH 11/47] Extend Einsum reference test cases Signed-off-by: MATEUSZ MIKOLAJCZYK --- .../tests/functional/op_reference/einsum.cpp | 200 ++++++++++++++++++ 1 file changed, 200 insertions(+) diff --git a/src/plugins/template/tests/functional/op_reference/einsum.cpp b/src/plugins/template/tests/functional/op_reference/einsum.cpp index 2d3e7fb627305f..4dd8f46a405472 100644 --- a/src/plugins/template/tests/functional/op_reference/einsum.cpp +++ b/src/plugins/template/tests/functional/op_reference/einsum.cpp @@ -154,6 +154,205 @@ std::vector generateParams() { .equation("abbac,bad->ad") .expectedResult({ET, {2, 1}, std::vector{123, 129}}) .testcaseName("einsum_diagonal_with_matmul"), + + Builder{} + .inputs({{ET, {2, 3}, std::vector{1, 2, 3, 4, 5, 6}}}) + .equation("...->...") + .expectedResult({ET, {2, 3}, std::vector{1, 2, 3, 4, 5, 6}}) + .testcaseName("einsum_identity"), + Builder{} + .inputs({{ET, {2, 3}, std::vector{1, 2, 3, 4, 5, 6}}}) + .equation("i...->i") + .expectedResult({ET, {2}, std::vector{6, 15}}) + .testcaseName("einsum_reduce_ellipsis"), + Builder{} + .inputs({{ET, {3, 3, 3}, std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27}}}) + .equation("iii->") + .expectedResult({ET, {}, std::vector{42}}) + .testcaseName("einsum_trace"), + Builder{} + .inputs({{ET, {3, 3, 4}, std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36}}}) + .equation("ii...->") + .expectedResult({ET, {}, std::vector{222}}) + .testcaseName("einsum_trace_ellipsis"), + Builder{} + .inputs({{ET, {3, 2, 1, 2, 1, 3, 1}, std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 
18, 19, 20, 21, 22, 23, 24, + 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36}}}) + .equation("ijkjkik->ijk") + .expectedResult({ET, {3, 2, 1}, std::vector{1, 10, 14, 23, 27, 36}}) + .testcaseName("einsum_diagonal_mixed_order"), + Builder{} + .inputs({{ET, + {3, 3, 3, 3, 3}, + std::vector{ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, + 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, + 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, + 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, + 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, + 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, + 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, + 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, + 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, + 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, + 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243}}}) + .equation("iiiii->i") + .expectedResult({ET, {3}, std::vector{1, 122, 243}}) + .testcaseName("einsum_5d_diagonal"), + Builder{} + .inputs({{ET, {2, 1}, std::vector{1, 2}}, + {ET, {4, 1, 1}, std::vector{1, 2, 3, 4}}, + {ET, {3, 1, 3}, std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9}}}) + .equation("ab,bcd,dbc->ca") + .expectedResult({ET, {3, 2}, std::vector{120, 240, 150, 300, 180, 360}}) + .testcaseName("einsum_3in_broadcast"), + Builder{} + .inputs({{ET, {2, 1}, std::vector{1, 2}}, {ET, {3, 2}, std::vector{1, 2, 3, 4, 5, 6}}}) + .equation("ab,bc->ac") + .expectedResult({ET, {2, 2}, std::vector{9, 12, 18, 24}}) + .testcaseName("einsum_2in_broadcast_lhs_reduced"), + Builder{} + .inputs({{ET, {2, 3}, std::vector{1, 2, 3, 4, 5, 6}}, {ET, {1, 2}, std::vector{1, 2}}}) + .equation("ab,bc->ac") + .expectedResult({ET, {2, 2}, std::vector{6, 12, 15, 30}}) + .testcaseName("einsum_2in_broadcast_rhs_reduced"), + Builder{} + .inputs({{ET, {2, 1}, std::vector{1, 2}}, {ET, {3, 2}, std::vector{1, 2, 3, 4, 5, 6}}}) + .equation("ab,bc->bc") + .expectedResult({ET, {3, 2}, std::vector{3, 6, 9, 12, 15, 18}}) + .testcaseName("einsum_2in_broadcast_lhs_common"), + Builder{} + .inputs({{ET, {2, 3}, std::vector{1, 2, 3, 4, 5, 6}}, {ET, {1, 2}, std::vector{1, 2}}}) + .equation("ab,bc->cb") + .expectedResult({ET, {2, 3}, std::vector{5, 7, 9, 10, 14, 18}}) + .testcaseName("einsum_2in_broadcast_rhs_common"), + Builder{} + .inputs({{ET, {1, 3}, std::vector{1, 2, 3}}, + {ET, {3, 4, 2, 1}, std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}}}) + .equation("aj,j...->a...") + .expectedResult({ET, {1, 4, 2, 1}, std::vector{70, 76, 82, 88, 94, 100, 106, 112}}) + .testcaseName("einsum_2in_only_rhs_out_ellipsis"), + Builder{} + .inputs({{ET, + {2, 7, 4, 3}, + std::vector{ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, + 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, + 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, + 77, 78, 79, 80, 
81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, + 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, + 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, + 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, + 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168}}, + {ET, {3}, std::vector{1, 2, 3}}}) + .equation("a...j,j->a...") + .expectedResult( + {ET, {2, 7, 4}, std::vector{14, 32, 50, 68, 86, 104, 122, 140, 158, 176, 194, 212, 230, 248, + 266, 284, 302, 320, 338, 356, 374, 392, 410, 428, 446, 464, 482, 500, + 518, 536, 554, 572, 590, 608, 626, 644, 662, 680, 698, 716, 734, 752, + 770, 788, 806, 824, 842, 860, 878, 896, 914, 932, 950, 968, 986, 1004}}) + .testcaseName("einsum_2in_only_lhs_out_ellipsis"), + Builder{} + .inputs({{ET, + {2, 7, 4, 3}, + std::vector{ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, + 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, + 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, + 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, + 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, + 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, + 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, + 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168}}, + {ET, {3}, std::vector{1, 2, 3}}}) + .equation("a...j,j->a") + .expectedResult({ET, {2}, std::vector{7196, 21308}}) + .testcaseName("einsum_2in_lhs_ellipsis_out_reduced"), + Builder{} + .inputs({{ET, {1, 3}, std::vector{1, 2, 3}}, + {ET, {3, 4, 2, 1}, std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}}}) + .equation("aj,j...->a") + .expectedResult({ET, {1}, std::vector{728}}) + .testcaseName("einsum_2in_rhs_ellipsis_out_reduced"), + Builder{} + .inputs({{ET, {1, 1, 4, 3}, std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}, + {ET, {3, 4, 2, 1}, std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}}}) + .equation("a...j,j...->a") + .expectedResult({ET, {1}, std::vector{8312}}) + .testcaseName("einsum_2in_broadcast_ellipsis_out_reduced"), + Builder{} + .inputs({{ET, {1, 3}, std::vector{1, 2, 3}}, + {ET, {3, 4, 2, 1}, std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}}}) + .equation("a...j,j...->a...") + .expectedResult({ET, {1, 4, 2, 1}, std::vector{70, 76, 82, 88, 94, 100, 106, 112}}) + .testcaseName("einsum_2in_unsqueeze_lhs_ellipsis"), + Builder{} + .inputs({{ET, {1, 1, 4, 3}, std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}, + {ET, {3}, std::vector{1, 2, 3}}}) + .equation("a...j,j...->a...") + .expectedResult({ET, {1, 1, 4}, std::vector{14, 32, 50, 68}}) + .testcaseName("einsum_2in_unsqueeze_rhs_ellipsis"), + Builder{} + .inputs({{ET, {1, 3}, std::vector{1, 2, 3}}, + {ET, {3, 4, 2, 1}, std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}}}) + .equation("a...j,j...->a") + .expectedResult({ET, {1}, std::vector{728}}) + .testcaseName("einsum_2in_unsqueeze_lhs_ellipsis_no_out_ellipsis"), + Builder{} + .inputs({{ET, {1, 1, 4, 3}, 
std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}, + {ET, {3}, std::vector{1, 2, 3}}}) + .equation("a...j,j...->a") + .expectedResult({ET, {1}, std::vector{164}}) + .testcaseName("einsum_2in_unsqueeze_rhs_ellipsis_no_out_ellipsis"), + Builder{} + .inputs({{ET, {1, 3}, std::vector{1, 2, 3}}, {ET, {3}, std::vector{1, 2, 3}}}) + .equation("a...j,j->a...") + .expectedResult({ET, {1}, std::vector{14}}) + .testcaseName("einsum_2in_prune_lhs_out_ellipsis"), + Builder{} + .inputs({{ET, {1, 3}, std::vector{1, 2, 3}}, {ET, {3}, std::vector{1, 2, 3}}}) + .equation("aj,j...->a...") + .expectedResult({ET, {1}, std::vector{14}}) + .testcaseName("einsum_2in_prune_rhs_out_ellipsis"), + Builder{} + .inputs({{ET, {1, 3}, std::vector{1, 2, 3}}, {ET, {3}, std::vector{1, 2, 3}}}) + .equation("aj,j->a...") + .expectedResult({ET, {1}, std::vector{14}}) + .testcaseName("einsum_2in_prune_out_ellipsis"), + Builder{} + .inputs({{ET, {1, 3}, std::vector{1, 2, 3}}, {ET, {3}, std::vector{1, 2, 3}}}) + .equation("a...j,j...->a...") + .expectedResult({ET, {1}, std::vector{14}}) + .testcaseName("einsum_2in_prune_all_ellipsis"), + Builder{} + .inputs({{ET, {1, 3}, std::vector{1, 2, 3}}, {ET, {1}, std::vector{1}}}) + .equation("a...j,j->a") + .expectedResult({ET, {1}, std::vector{6}}) + .testcaseName("einsum_2in_prune_lhs_ellipsis_no_out_ellipsis"), + Builder{} + .inputs({{ET, {1, 1}, std::vector{1}}, {ET, {3}, std::vector{1, 2, 3}}}) + .equation("aj,j...->a") + .expectedResult({ET, {1}, std::vector{6}}) + .testcaseName("einsum_2in_prune_rhs_ellipsis_no_out_ellipsis"), + Builder{} + .inputs({{ET, {1, 3}, std::vector{1, 2, 3}}, {ET, {3}, std::vector{1, 2, 3}}}) + .equation("a...j,j...->a") + .expectedResult({ET, {1}, std::vector{14}}) + .testcaseName("einsum_2in_prune_inp_ellipsis_no_out_ellipsis") + }; return params; } @@ -161,6 +360,7 @@ std::vector generateParams() { std::vector generateCombinedParams() { const std::vector> generatedParams{ generateParams(), + generateParams(), generateParams(), }; std::vector combinedParams; From 9e749f9bd8690532f72c568c03114b2bb4377670 Mon Sep 17 00:00:00 2001 From: Mateusz Mikolajczyk Date: Thu, 23 Jan 2025 15:50:03 +0100 Subject: [PATCH 12/47] FIx divide by 0 and handling 2+ repeated label types for einsum decomposition Signed-off-by: Mateusz Mikolajczyk --- .../op_conversions/einsum_decomposition.cpp | 48 +++++++++++-------- 1 file changed, 29 insertions(+), 19 deletions(-) diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp index f21cfcf74aafc5..26f88a3df71f3e 100644 --- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp @@ -16,7 +16,9 @@ #include "openvino/op/einsum.hpp" #include "openvino/op/gather.hpp" #include "openvino/op/matmul.hpp" +#include "openvino/op/maximum.hpp" #include "openvino/op/multiply.hpp" +#include "openvino/op/power.hpp" #include "openvino/op/range.hpp" #include "openvino/op/reduce_prod.hpp" #include "openvino/op/reduce_sum.hpp" @@ -610,21 +612,25 @@ ov::Output build_identity(const ov::Output& input_node, const auto input_shape = std::make_shared(input_node); const auto repeated_label_indices = ov::op::v0::Constant::create(ov::element::i64, {repeated_label_dims.size()}, repeated_label_dims); + const auto repeated_label_indices_len = + ov::op::v0::Constant::create(ov::element::i64, {}, 
{repeated_label_dims.size()}); const auto const_0 = ov::op::v0::Constant::create(ov::element::i64, {}, {0}); const auto const_1 = ov::op::v0::Constant::create(ov::element::i64, {}, {1}); const auto repeated_dimensions = std::make_shared(input_shape, repeated_label_indices, const_0); const auto reduced_dimension = std::make_shared(repeated_dimensions, const_0, const_0); - const auto reduced_dimension_min_1 = std::make_shared(reduced_dimension, const_1); - - const auto reduced_size = std::make_shared(repeated_dimensions, const_0, true); - const auto reduced_size_min_1 = std::make_shared(reduced_size, const_1); - const auto step_size = std::make_shared(reduced_size_min_1, reduced_dimension_min_1); - const auto range = std::make_shared(const_0, reduced_dimension, const_1, ov::element::i64); - const auto steps = std::make_shared(range, step_size); - const auto zeros = std::make_shared(const_0, reduced_size); + const auto range_max_val = std::make_shared(reduced_dimension, repeated_label_indices_len); + const auto step_numerator = std::make_shared(range_max_val, const_1); + const auto step_denominator = std::make_shared(reduced_dimension, const_1); + const auto step_denominator_but_not_0 = std::make_shared(step_denominator, const_1); + const auto step_numerator_but_not_0 = std::make_shared(step_numerator, const_1); + const auto step = std::make_shared(step_numerator_but_not_0, step_denominator_but_not_0); + const auto eye_flattened_indices = std::make_shared(const_0, range_max_val, step); const auto reduced_dimension_1d = std::make_shared(reduced_dimension, const_0); const auto ones = std::make_shared(const_1, reduced_dimension_1d); - const auto eye_flattened = std::make_shared(zeros, steps, ones, const_0); + const auto reduced_size = std::make_shared(repeated_dimensions, const_0, true); + const auto zeros = std::make_shared(const_0, reduced_size); + const auto eye_flattened = + std::make_shared(zeros, eye_flattened_indices, ones, const_0); const auto identity_rank = std::make_shared(input_shape); const auto ones_of_input_shape_rank = std::make_shared(const_1, identity_rank); @@ -632,24 +638,28 @@ ov::Output build_identity(const ov::Output& input_node, repeated_label_indices, repeated_dimensions, const_0); + const auto identity = std::make_shared(eye_flattened, identity_shape, false); const auto identity_cvt = std::make_shared(identity, input_node.get_element_type()); subgraph_nodes.insert(subgraph_nodes.end(), {input_shape, repeated_label_indices, + repeated_label_indices_len, const_0, const_1, repeated_dimensions, reduced_dimension, - reduced_dimension_min_1, - reduced_size, - reduced_size_min_1, - step_size, - range, - steps, - zeros, + range_max_val, + step_numerator, + step_denominator, + step_denominator_but_not_0, + step_numerator_but_not_0, + step, + eye_flattened_indices, reduced_dimension_1d, ones, + reduced_size, + zeros, eye_flattened, identity_rank, ones_of_input_shape_rank, @@ -673,12 +683,12 @@ ov::Output build_multi_identity(ov::pass::EinsumDecomposition* einsum_ }; // initially set multi-identity with identity for the first repeated label - const auto multi_identity = get_identity(0); + auto multi_identity = get_identity(0).get_node_shared_ptr(); for (size_t label_ind = 1; label_ind < repeated_labels.size(); ++label_ind) { const auto identity = get_identity(label_ind); - const auto mul = + multi_identity = std::make_shared(multi_identity, identity, ov::op::AutoBroadcastType::NUMPY); - subgraph_nodes.insert(subgraph_nodes.end(), {mul}); + 
subgraph_nodes.insert(subgraph_nodes.end(), {multi_identity});
     }
 
     return subgraph_nodes.back();
 }

From 50b6d3ef81700e7b9a21048c78e253c6a72d2b48 Mon Sep 17 00:00:00 2001
From: Mateusz Mikolajczyk
Date: Thu, 23 Jan 2025 17:45:54 +0100
Subject: [PATCH 13/47] Move fix_inputs_with_0d_ellipsis to separate function

Signed-off-by: Mateusz Mikolajczyk

---
 .../op_conversions/einsum_decomposition.cpp   | 94 +++++++++++--------
 1 file changed, 56 insertions(+), 38 deletions(-)

diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp
index 26f88a3df71f3e..0d322c99af813e 100644
--- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp
+++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp
@@ -1079,6 +1079,61 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr,
     // update a vector of nodes for copy_runtime_info
     subgraph_nodes.insert(subgraph_nodes.end(), {matmul});
 }
+
+/// \brief Adjusts input subscripts and nodes to handle 0-dimensional ellipsis in Einsum operations.
+///
+/// Handle ellipsis labels that do not represent any dimensions:
+/// 1. If there is no ellipsis in the input subscripts, remove ellipsis from the output subscript.
+/// 2. If all ellipses in the input subscripts do not represent any dimensions, remove ellipses from all subscripts.
+/// 3. If there is at least one ellipsis that does not represent any dimensions, unsqueeze the corresponding input at
+/// the ellipsis position.
+///
+/// \param input_nodes A vector of input nodes for the Einsum operation.
+/// \param input_subscripts A vector of input subscripts corresponding to the input nodes.
+/// \param output_subscript The output subscript for the Einsum operation.
+/// \param subgraph_nodes A vector to store nodes created during the subgraph transformation.
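Note: before the helper that follows, a concrete reading of the three cases listed above; the equation and shapes here are examples chosen for this note, not taken from the tests:

#include <cstddef>
#include <string>
#include <vector>

// Sketch of the emptiness test the helper performs per input: an ellipsis is
// "empty" when the input rank equals the label count minus one, i.e. "..."
// matched zero dimensions. For "a...j" this returns true on a {1, 3} input
// (cases 2 and 3) and false on a {1, 4, 3} input.
bool ellipsis_is_empty(std::size_t input_rank, const std::vector<std::string>& labels) {
    bool has_ellipsis = false;
    for (const auto& label : labels)
        has_ellipsis = has_ellipsis || (label == "...");
    return has_ellipsis && input_rank + 1 == labels.size();
}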
+void fix_inputs_with_0d_ellipsis(ov::OutputVector& input_nodes, + std::vector& input_subscripts, + std::string& output_subscript, + ov::NodeVector& subgraph_nodes) { + std::vector ellipsis_inputs(input_nodes.size(), false); + std::vector no_ellipsis_or_empty_inputs(input_nodes.size(), false); + static const std::string ellipsis = "..."; + for (size_t inp_iter = 0; inp_iter < input_nodes.size(); inp_iter++) { + const auto& labels = ov::op::v7::Einsum::extract_labels(input_subscripts[inp_iter]); + ellipsis_inputs[inp_iter] = (std::find(labels.begin(), labels.end(), "...") != labels.end()); + if (!ellipsis_inputs[inp_iter] || (input_nodes[inp_iter].get_partial_shape().rank() == (labels.size() - 1))) { + no_ellipsis_or_empty_inputs[inp_iter] = true; + } + } + if (std::none_of(ellipsis_inputs.begin(), ellipsis_inputs.end(), [](bool inp) { + return inp; + })) { + if (output_subscript.find("...") != std::string::npos) { + output_subscript.erase(output_subscript.find("..."), 3); + } + } else if (std::all_of(no_ellipsis_or_empty_inputs.begin(), no_ellipsis_or_empty_inputs.end(), [](bool inp) { + return inp; + })) { + for (size_t inp_iter = 0; inp_iter < input_nodes.size(); inp_iter++) { + if (input_subscripts[inp_iter].find("...") != std::string::npos) { + input_subscripts[inp_iter].erase(input_subscripts[inp_iter].find("..."), 3); + } + } + if (output_subscript.find("...") != std::string::npos) { + output_subscript.erase(output_subscript.find("..."), 3); + } + } else { + for (size_t inp_iter = 0; inp_iter < input_nodes.size(); inp_iter++) { + if (ellipsis_inputs[inp_iter] && no_ellipsis_or_empty_inputs[inp_iter]) { + auto labels = ov::op::v7::Einsum::extract_labels(input_subscripts[inp_iter]); + auto ellipsis_idx_iter = std::find(labels.begin(), labels.end(), "..."); + std::vector ellipsis_idx{std::distance(labels.begin(), ellipsis_idx_iter)}; + input_nodes[inp_iter] = unsqueeze_input(input_nodes[inp_iter], ellipsis_idx, subgraph_nodes); + } + } + } +} } // namespace ov::pass::EinsumDecomposition::EinsumDecomposition() { @@ -1116,44 +1171,7 @@ ov::pass::EinsumDecomposition::EinsumDecomposition() { auto einsum_path = compute_einsum_path(einsum_node); // fix inputs where ellipsis does not contain any dimensions - std::vector ellipsis_inputs(input_nodes.size(), false); - std::vector no_ellipsis_or_empty_inputs(input_nodes.size(), false); - static const std::string ellipsis = "..."; - for (size_t inp_iter = 0; inp_iter < input_nodes.size(); inp_iter++) { - const auto& labels = ov::op::v7::Einsum::extract_labels(input_subscripts[inp_iter]); - ellipsis_inputs[inp_iter] = (std::find(labels.begin(), labels.end(), "...") != labels.end()); - if (!ellipsis_inputs[inp_iter] || - (input_nodes[inp_iter].get_partial_shape().rank() == (labels.size() - 1))) { - no_ellipsis_or_empty_inputs[inp_iter] = true; - } - } - if (std::none_of(ellipsis_inputs.begin(), ellipsis_inputs.end(), [](bool inp) { - return inp; - })) { - if (output_subscript.find("...") != std::string::npos) { - output_subscript.erase(output_subscript.find("..."), 3); - } - } else if (std::all_of(no_ellipsis_or_empty_inputs.begin(), no_ellipsis_or_empty_inputs.end(), [](bool inp) { - return inp; - })) { - for (size_t inp_iter = 0; inp_iter < input_nodes.size(); inp_iter++) { - if (input_subscripts[inp_iter].find("...") != std::string::npos) { - input_subscripts[inp_iter].erase(input_subscripts[inp_iter].find("..."), 3); - } - } - if (output_subscript.find("...") != std::string::npos) { - output_subscript.erase(output_subscript.find("..."), 3); - } - 
} else { - for (size_t inp_iter = 0; inp_iter < input_nodes.size(); inp_iter++) { - if (ellipsis_inputs[inp_iter] && no_ellipsis_or_empty_inputs[inp_iter]) { - auto labels = ov::op::v7::Einsum::extract_labels(input_subscripts[inp_iter]); - auto ellipsis_idx_iter = std::find(labels.begin(), labels.end(), "..."); - std::vector ellipsis_idx{std::distance(labels.begin(), ellipsis_idx_iter)}; - input_nodes[inp_iter] = unsqueeze_input(input_nodes[inp_iter], ellipsis_idx, subgraph_nodes); - } - } - } + fix_inputs_with_0d_ellipsis(input_nodes, input_subscripts, output_subscript, subgraph_nodes); // contract inputs by Einsum until just one is remained for (auto const& inds_pair : einsum_path) { From f666700635e2f5836f3461855b98e1ab738dd1a0 Mon Sep 17 00:00:00 2001 From: Mateusz Mikolajczyk Date: Thu, 23 Jan 2025 18:07:58 +0100 Subject: [PATCH 14/47] Modify reshape_input_for_matmul reduced prod to match ne for separate Signed-off-by: Mateusz Mikolajczyk --- .../op_conversions/einsum_decomposition.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp index 0d322c99af813e..a882c65f0b1dd5 100644 --- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp @@ -422,6 +422,7 @@ ov::Output reshape_input_for_matmul(const ov::Output& input_ // compute a product of a sub-shape for separate labels ov::OutputVector separate_parts; + auto reduce_axis_const = ov::op::v0::Constant::create(ov::element::Type_t::i64, ov::Shape{1}, {0}); if (common_sub_shape.size() > 0 && separate_sub_shape.size() == 0) { // in this case new dimension corresponding to separate labels must be added // since MatMul operation is not possible to do without separate dimensions if the @@ -432,7 +433,6 @@ ov::Output reshape_input_for_matmul(const ov::Output& input_ } else if (separate_sub_shape.size() > 0) { // in this case compute a product of separate dimension sizes since they must be // presented with just one dimension for MatMul - auto reduce_axis_const = ov::op::v0::Constant::create(ov::element::Type_t::i64, ov::Shape{1}, {0}); auto separate_shape_prod = std::make_shared(separate_sub_shape[0], reduce_axis_const, true); separate_parts.push_back(separate_shape_prod->output(0)); @@ -440,9 +440,9 @@ ov::Output reshape_input_for_matmul(const ov::Output& input_ } ov::OutputVector reduced_sub_shape_prod; auto const_0 = ov::op::v0::Constant::create(ov::element::i32, {1}, {0}); - for (auto sub_shape : reduced_sub_shape) { - auto product = std::make_shared(sub_shape, const_0, true); - subgraph_nodes.insert(subgraph_nodes.end(), {const_0, product}); + if (reduced_sub_shape.size() > 0) { + auto product = std::make_shared(reduced_sub_shape[0], const_0, true); + subgraph_nodes.insert(subgraph_nodes.end(), {reduce_axis_const, product}); reduced_sub_shape_prod.push_back(product->output(0)); } From 6347ed2d072f6b5eef71e2583debdb0a8caeede7 Mon Sep 17 00:00:00 2001 From: Mateusz Mikolajczyk Date: Thu, 23 Jan 2025 19:05:51 +0100 Subject: [PATCH 15/47] Refactor empty ellipsis handling in Einsum decomposition to improve clarity Signed-off-by: Mateusz Mikolajczyk --- .../op_conversions/einsum_decomposition.cpp | 60 ++++++++++--------- 1 file changed, 31 insertions(+), 29 deletions(-) diff --git 
a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp index a882c65f0b1dd5..b58b58b57ae5d2 100644 --- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp @@ -1085,8 +1085,7 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, /// Handle ellipses labels that do not represent any dimensions: /// 1. If there is no ellipsis in the input subscripts, remove ellipsis from the output subscript. /// 2. If all ellipses in the input subscripts do not represent any dimensions, remove ellipses from all subscripts. -/// 3. If there is at least one ellipsis that does not represent any dimensions, unsqueeze the corresponding input at -/// ellipsis dimension. +/// 3. If there is at least one ellipsis that represents dimensions, unsqueeze the ellipses that do not represent any. /// /// \param input_nodes A vector of input nodes for the Einsum operation. /// \param input_subscripts A vector of input subscripts corresponding to the input nodes. @@ -1096,40 +1095,43 @@ void fix_inputs_with_0d_ellipsis(ov::OutputVector& input_nodes, std::vector& input_subscripts, std::string& output_subscript, ov::NodeVector& subgraph_nodes) { - std::vector ellipsis_inputs(input_nodes.size(), false); - std::vector no_ellipsis_or_empty_inputs(input_nodes.size(), false); static const std::string ellipsis = "..."; - for (size_t inp_iter = 0; inp_iter < input_nodes.size(); inp_iter++) { - const auto& labels = ov::op::v7::Einsum::extract_labels(input_subscripts[inp_iter]); - ellipsis_inputs[inp_iter] = (std::find(labels.begin(), labels.end(), "...") != labels.end()); - if (!ellipsis_inputs[inp_iter] || (input_nodes[inp_iter].get_partial_shape().rank() == (labels.size() - 1))) { - no_ellipsis_or_empty_inputs[inp_iter] = true; - } + bool has_ellipsis = false; + bool all_no_ellipsis_or_empty = true; + + for (size_t i = 0; i < input_nodes.size(); ++i) { + const auto& labels = ov::op::v7::Einsum::extract_labels(input_subscripts[i]); + bool has_ellipsis_in_input = std::find(labels.begin(), labels.end(), ellipsis) != labels.end(); + has_ellipsis |= has_ellipsis_in_input; + all_no_ellipsis_or_empty &= + !has_ellipsis_in_input || (input_nodes[i].get_partial_shape().rank().get_length() == + static_cast(labels.size() - 1)); } - if (std::none_of(ellipsis_inputs.begin(), ellipsis_inputs.end(), [](bool inp) { - return inp; - })) { - if (output_subscript.find("...") != std::string::npos) { - output_subscript.erase(output_subscript.find("..."), 3); + + if (!has_ellipsis) { + if (output_subscript.find(ellipsis) != std::string::npos) { + output_subscript.erase(output_subscript.find(ellipsis), ellipsis.size()); } - } else if (std::all_of(no_ellipsis_or_empty_inputs.begin(), no_ellipsis_or_empty_inputs.end(), [](bool inp) { - return inp; - })) { - for (size_t inp_iter = 0; inp_iter < input_nodes.size(); inp_iter++) { - if (input_subscripts[inp_iter].find("...") != std::string::npos) { - input_subscripts[inp_iter].erase(input_subscripts[inp_iter].find("..."), 3); + } else if (all_no_ellipsis_or_empty) { + for (auto& subscript : input_subscripts) { + if (subscript.find(ellipsis) != std::string::npos) { + subscript.erase(subscript.find(ellipsis), ellipsis.size()); } } - if (output_subscript.find("...") != std::string::npos) { - output_subscript.erase(output_subscript.find("..."), 3);
+ if (output_subscript.find(ellipsis) != std::string::npos) { + output_subscript.erase(output_subscript.find(ellipsis), ellipsis.size()); } } else { - for (size_t inp_iter = 0; inp_iter < input_nodes.size(); inp_iter++) { - if (ellipsis_inputs[inp_iter] && no_ellipsis_or_empty_inputs[inp_iter]) { - auto labels = ov::op::v7::Einsum::extract_labels(input_subscripts[inp_iter]); - auto ellipsis_idx_iter = std::find(labels.begin(), labels.end(), "..."); - std::vector ellipsis_idx{std::distance(labels.begin(), ellipsis_idx_iter)}; - input_nodes[inp_iter] = unsqueeze_input(input_nodes[inp_iter], ellipsis_idx, subgraph_nodes); + for (size_t i = 0; i < input_nodes.size(); ++i) { + const auto& labels = ov::op::v7::Einsum::extract_labels(input_subscripts[i]); + if (std::find(labels.begin(), labels.end(), ellipsis) != labels.end() && + input_nodes[i].get_partial_shape().rank().get_length() == + static_cast(labels.size() - 1)) { + input_nodes[i] = unsqueeze_input( + input_nodes[i], + {static_cast( + std::distance(labels.begin(), std::find(labels.begin(), labels.end(), ellipsis)))}, + subgraph_nodes); } } } 

From d380155756cabcab19084663d54f9d4fde8c23de Mon Sep 17 00:00:00 2001 From: Mateusz Mikolajczyk Date: Fri, 24 Jan 2025 12:25:27 +0100 Subject: [PATCH 16/47] Refactor handling of 0-dimensional ellipsis in Einsum operations for improved clarity Signed-off-by: Mateusz Mikolajczyk --- src/core/reference/src/op/einsum.cpp | 93 ++++++++++++++++------------ 1 file changed, 55 insertions(+), 38 deletions(-) diff --git a/src/core/reference/src/op/einsum.cpp b/src/core/reference/src/op/einsum.cpp index 4457b628f670be..ec0be118fc13e1 100644 --- a/src/core/reference/src/op/einsum.cpp +++ b/src/core/reference/src/op/einsum.cpp @@ -927,55 +927,72 @@ void contract_two_inputs(ov::TensorVector& inputs, update_operands(inputs, input_subscripts, input_ind1, input_ind2, contract_output, resultant_subscript); } +/// \brief Adjusts input subscripts and nodes to handle 0-dimensional ellipsis in Einsum operations. +/// +/// Handle ellipses labels that do not represent any dimensions: +/// 1. If there is no ellipsis in the input subscripts, remove ellipsis from the output subscript. +/// 2. If all ellipses in the input subscripts do not represent any dimensions, remove ellipses from all subscripts. +/// 3. If there is at least one ellipsis that represents dimensions, unsqueeze the ellipses that do not represent any. +/// +/// \param input_nodes A vector of input tensors for the Einsum operation. +/// \param input_subscripts A vector of input subscripts corresponding to the input nodes. +/// \param output_subscript The output subscript for the Einsum operation.
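// Reviewer note: a minimal standalone sketch of rules 1 and 2 above, operating on
// plain subscript strings and ranks instead of ov::Tensor. All names below are
// illustrative assumptions, not OpenVINO API; rule 3 additionally needs the tensor
// itself (unsqueeze at the ellipsis position), so it is only noted in a comment.
#include <cstddef>
#include <string>
#include <vector>

static void erase_ellipsis_once(std::string& subscript) {
    const auto pos = subscript.find("...");
    if (pos != std::string::npos)
        subscript.erase(pos, 3);
}

// ranks[i] is the rank of input i; label_counts[i] counts "..." as one label.
static void fix_empty_ellipses_sketch(std::vector<std::string>& in_subs,
                                      const std::vector<size_t>& ranks,
                                      const std::vector<size_t>& label_counts,
                                      std::string& out_sub) {
    bool has_ellipsis = false;
    bool all_empty = true;
    for (size_t i = 0; i < in_subs.size(); ++i) {
        const bool has = in_subs[i].find("...") != std::string::npos;
        has_ellipsis |= has;
        // the ellipsis covers zero dimensions iff rank + 1 == number of labels
        all_empty &= !has || (ranks[i] + 1 == label_counts[i]);
    }
    if (!has_ellipsis) {         // rule 1: drop a dangling "..." from the output
        erase_ellipsis_once(out_sub);
    } else if (all_empty) {      // rule 2: drop "..." from every subscript
        for (auto& s : in_subs)
            erase_ellipsis_once(s);
        erase_ellipsis_once(out_sub);
    }                            // rule 3: otherwise unsqueeze the empty ellipses
}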
template -void einsum_impl(const ov::TensorVector& inputs, ov::TensorVector& outputs, const std::string& equation) { - std::vector input_subscripts; - std::string output_subscript; - ov::op::v7::Einsum::parse_equation(equation, input_subscripts, output_subscript); - - // compute einsum path that is used to contract a pair of operands - // in more optimal order - size_t num_inputs = inputs.size(); - auto einsum_path = compute_einsum_path(num_inputs); - - ov::TensorVector int_inputs = inputs; - std::vector ellipsis_inputs(inputs.size(), false); - std::vector no_ellipsis_or_empty_inputs(inputs.size(), false); +void fix_inputs_with_0d_ellipsis(ov::TensorVector& input_nodes, + std::vector& input_subscripts, + std::string& output_subscript) { static const std::string ellipsis = "..."; - for (size_t inp_iter = 0; inp_iter < inputs.size(); inp_iter++) { - const auto& labels = ov::op::v7::Einsum::extract_labels(input_subscripts[inp_iter]); - ellipsis_inputs[inp_iter] = (std::find(labels.begin(), labels.end(), "...") != labels.end()); - if (!ellipsis_inputs[inp_iter] || (inputs[inp_iter].get_shape().size() == (labels.size() - 1))) { - no_ellipsis_or_empty_inputs[inp_iter] = true; - } + bool has_ellipsis = false; + bool all_no_ellipsis_or_empty = true; + + for (size_t i = 0; i < input_nodes.size(); ++i) { + const auto& labels = ov::op::v7::Einsum::extract_labels(input_subscripts[i]); + bool has_ellipsis_in_input = std::find(labels.begin(), labels.end(), ellipsis) != labels.end(); + has_ellipsis |= has_ellipsis_in_input; + all_no_ellipsis_or_empty &= + !has_ellipsis_in_input || (input_nodes[i].get_shape().size() == (labels.size() - 1)); } - if (std::none_of(ellipsis_inputs.begin(), ellipsis_inputs.end(), [](bool inp) { - return inp; - })) { - if (output_subscript.find("...") != std::string::npos) { - output_subscript.erase(output_subscript.find("..."), 3); + + if (!has_ellipsis) { + if (output_subscript.find(ellipsis) != std::string::npos) { + output_subscript.erase(output_subscript.find(ellipsis), ellipsis.size()); } - } else if (std::all_of(no_ellipsis_or_empty_inputs.begin(), no_ellipsis_or_empty_inputs.end(), [](bool inp) { - return inp; - })) { - for (size_t inp_iter = 0; inp_iter < inputs.size(); inp_iter++) { - if (input_subscripts[inp_iter].find("...") != std::string::npos) { - input_subscripts[inp_iter].erase(input_subscripts[inp_iter].find("..."), 3); + } else if (all_no_ellipsis_or_empty) { + for (auto& subscript : input_subscripts) { + if (subscript.find(ellipsis) != std::string::npos) { + subscript.erase(subscript.find(ellipsis), ellipsis.size()); } } - if (output_subscript.find("...") != std::string::npos) { - output_subscript.erase(output_subscript.find("..."), 3); + if (output_subscript.find(ellipsis) != std::string::npos) { + output_subscript.erase(output_subscript.find(ellipsis), ellipsis.size()); } } else { - for (size_t inp_iter = 0; inp_iter < inputs.size(); inp_iter++) { - if (ellipsis_inputs[inp_iter] && no_ellipsis_or_empty_inputs[inp_iter]) { - auto labels = ov::op::v7::Einsum::extract_labels(input_subscripts[inp_iter]); - auto ellipsis_idx_iter = std::find(labels.begin(), labels.end(), "..."); - std::vector ellipsis_idx{std::distance(labels.begin(), ellipsis_idx_iter)}; - int_inputs[inp_iter] = unsqueeze_input(inputs[inp_iter], ellipsis_idx); + for (size_t i = 0; i < input_nodes.size(); ++i) { + const auto& labels = ov::op::v7::Einsum::extract_labels(input_subscripts[i]); + if (std::find(labels.begin(), labels.end(), ellipsis) != labels.end() && + 
input_nodes[i].get_shape().size() == (labels.size() - 1)) { + std::vector ellipsis_idx{ + std::distance(labels.begin(), std::find(labels.begin(), labels.end(), ellipsis))}; + input_nodes[i] = unsqueeze_input(input_nodes[i], ellipsis_idx); } } } +} + +template +void einsum_impl(const ov::TensorVector& inputs, ov::TensorVector& outputs, const std::string& equation) { + std::vector input_subscripts; + std::string output_subscript; + ov::op::v7::Einsum::parse_equation(equation, input_subscripts, output_subscript); + + // compute einsum path that is used to contract a pair of operands + // in more optimal order + size_t num_inputs = inputs.size(); + auto einsum_path = compute_einsum_path(num_inputs); + ov::TensorVector int_inputs = inputs; + + // fix inputs where ellipsis does not contain any dimensions + fix_inputs_with_0d_ellipsis(int_inputs, input_subscripts, output_subscript); // contract inputs by Einsum until just one is remained for (auto const& inds_pair : einsum_path) { From 191899102de2eae9d1560999ea5b9d43bcd45f9e Mon Sep 17 00:00:00 2001 From: Mateusz Mikolajczyk Date: Fri, 24 Jan 2025 19:05:56 +0100 Subject: [PATCH 17/47] Refactor broadcast_merge_shapes to eliminate loop Signed-off-by: Mateusz Mikolajczyk --- .../op_conversions/einsum_decomposition.cpp | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp index b58b58b57ae5d2..35cc48253b0e1e 100644 --- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp @@ -322,22 +322,24 @@ ov::Output unsqueeze_input(const ov::Output& input_node, ov::OutputVector broadcast_merge_shapes(ov::OutputVector& shapes_lhs, ov::OutputVector& shapes_rhs, ov::NodeVector& subgraph_nodes) { - // TODO - Refactor func to remove loop and duplicated Broadcast. 
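// Reviewer note: the refactored function below realizes NumPy broadcasting in the
// graph with two Broadcast ops (1 broadcast to the lhs shape, then BIDIRECTIONAL
// against the rhs shape). For already-static shapes the same merge can be computed
// directly; a hypothetical standalone sketch, not OpenVINO code:
#include <algorithm>
#include <stdexcept>
#include <vector>

static std::vector<size_t> broadcast_merge_sketch(std::vector<size_t> lhs, std::vector<size_t> rhs) {
    // NumPy rules: right-align the shapes, pad the shorter one with 1s on the left.
    if (lhs.size() < rhs.size())
        lhs.insert(lhs.begin(), rhs.size() - lhs.size(), 1);
    if (rhs.size() < lhs.size())
        rhs.insert(rhs.begin(), lhs.size() - rhs.size(), 1);
    std::vector<size_t> merged(lhs.size());
    for (size_t i = 0; i < lhs.size(); ++i) {
        if (lhs[i] != rhs[i] && lhs[i] != 1 && rhs[i] != 1)
            throw std::runtime_error("shapes are not broadcastable");
        merged[i] = std::max(lhs[i], rhs[i]);  // e.g. {4,1,3} vs {2,1} -> {4,2,3}
    }
    return merged;
}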
- OPENVINO_ASSERT(shapes_lhs.size() == shapes_rhs.size()); - ov::OutputVector broadcasted_shape_nodes{shapes_lhs.size()}; - - for (size_t shp_i = 0; shp_i < shapes_lhs.size(); shp_i++) { + ov::OutputVector broadcasted_shape_nodes{}; + // OutputVector is either empty or contains a single shape + if (shapes_lhs.size() == 1 && shapes_rhs.size() == 1) { auto const_1 = ov::op::v0::Constant::create(ov::element::Type_t::i64, ov::Shape{1}, {1}); auto tmp_const_of_lhs_shp = - std::make_shared(const_1, shapes_lhs[shp_i], ov::op::BroadcastType::NUMPY); + std::make_shared(const_1, shapes_lhs[0], ov::op::BroadcastType::NUMPY); auto tmp_const_of_broadcasted_shp = std::make_shared(tmp_const_of_lhs_shp, - shapes_rhs[shp_i], + shapes_rhs[0], ov::op::BroadcastType::BIDIRECTIONAL); auto broadcasted_shape = std::make_shared(tmp_const_of_broadcasted_shp); - broadcasted_shape_nodes[shp_i] = broadcasted_shape; + broadcasted_shape_nodes.push_back(broadcasted_shape->output(0)); subgraph_nodes.insert(subgraph_nodes.end(), {const_1, tmp_const_of_lhs_shp, tmp_const_of_broadcasted_shp, broadcasted_shape}); + } else if (shapes_lhs.size() == 0 && shapes_rhs.size() == 1) { + return shapes_rhs; + } else if (shapes_lhs.size() == 1 && shapes_rhs.size() == 0) { + return shapes_lhs; } return broadcasted_shape_nodes; } From 2eee35cd98f0bd78fd7dfc289bb254396380da28 Mon Sep 17 00:00:00 2001 From: Mateusz Mikolajczyk Date: Thu, 30 Jan 2025 11:21:19 +0100 Subject: [PATCH 18/47] Fix shape_infer for reduced out ellipsis with dynamic rank inputs Signed-off-by: Mateusz Mikolajczyk --- .../include/einsum_shape_inference.hpp | 13 ++-- src/core/tests/type_prop/einsum.cpp | 60 +++++++++++++++++++ 2 files changed, 65 insertions(+), 8 deletions(-) diff --git a/src/core/shape_inference/include/einsum_shape_inference.hpp b/src/core/shape_inference/include/einsum_shape_inference.hpp index 1ee471117d6872..007b220ddda3d6 100644 --- a/src/core/shape_inference/include/einsum_shape_inference.hpp +++ b/src/core/shape_inference/include/einsum_shape_inference.hpp @@ -24,10 +24,13 @@ std::vector shape_infer(const Einsum* op, const std::vector& input_s input_subscripts.size() == input_shapes.size(), "Equation must contain a number of subscripts equal to a number of Einsum inputs."); + const auto output_labels = Einsum::extract_labels(output_subscript); + const auto has_out_ellipsis = std::any_of(output_labels.begin(), output_labels.end(), [](std::string label) { + return label == "..."; + }); // create a dictionary with dimension sizes (or ranges in case of dynamic shapes) for each label // and check their compatibility in case of repeating labels std::unordered_map label_to_shape; - for (size_t input_idx = 0; input_idx < input_shapes.size(); ++input_idx) { const auto& pshape = input_shapes[input_idx]; const auto labels = Einsum::extract_labels(input_subscripts[input_idx]); @@ -78,16 +81,11 @@ std::vector shape_infer(const Einsum* op, const std::vector& input_s } } } else { - if (has_ellipsis) { + if (has_ellipsis && has_out_ellipsis) { // Shape has dynamic rank and ellipsis return {pshape}; } for (auto const& label : labels) { - NODE_VALIDATION_CHECK(op, - label != "...", - "The subscript corresponding to a dynamic rank input must " - "not contain ellipsis."); - if (label_to_shape.find(label) == label_to_shape.end()) { label_to_shape[label] = ov::PartialShape{Dimension::dynamic()}; } @@ -96,7 +94,6 @@ std::vector shape_infer(const Einsum* op, const std::vector& input_s } // compute the output shape - const auto output_labels = 
Einsum::extract_labels(output_subscript); auto output_shapes = std::vector(1); auto& output_shape = output_shapes[0]; diff --git a/src/core/tests/type_prop/einsum.cpp b/src/core/tests/type_prop/einsum.cpp index 4772393a89f497..5c96289af647ea 100644 --- a/src/core/tests/type_prop/einsum.cpp +++ b/src/core/tests/type_prop/einsum.cpp @@ -478,6 +478,66 @@ TEST_F(TypePropEinsumTest, all_dynamic_rank_ellipsis) { EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr)); } +TEST_F(TypePropEinsumTest, lhs_dynamic_rank_ellipsis) { + const std::string equation = "a...b,b...->...a"; + constexpr auto et = element::i32; + + auto input_shapes = PartialShapes{PartialShape::dynamic(), {3, 11, 7, 4}}; + const auto inputs = make_inputs(et, input_shapes); + const auto o = make_op(inputs, equation); + + EXPECT_EQ(o->get_equation(), equation); + EXPECT_EQ(o->get_element_type(), et); + EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count); + EXPECT_EQ(o->get_output_partial_shape(0), PartialShape::dynamic()); + EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr)); +} + +TEST_F(TypePropEinsumTest, rhs_dynamic_rank_ellipsis) { + const std::string equation = "a...b,b...->...a"; + constexpr auto et = element::i32; + + auto input_shapes = PartialShapes{{3, 11, 7, 4}, PartialShape::dynamic()}; + const auto inputs = make_inputs(et, input_shapes); + const auto o = make_op(inputs, equation); + + EXPECT_EQ(o->get_equation(), equation); + EXPECT_EQ(o->get_element_type(), et); + EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count); + EXPECT_EQ(o->get_output_partial_shape(0), PartialShape::dynamic()); + EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr)); +} + +TEST_F(TypePropEinsumTest, lhs_dynamic_rank_ellipsis_reduced_out_ellipsis) { + const std::string equation = "a...b,b...->a"; + constexpr auto et = element::i32; + + auto input_shapes = PartialShapes{PartialShape::dynamic(), {3, 11, 7, 4}}; + const auto inputs = make_inputs(et, input_shapes); + const auto o = make_op(inputs, equation); + + EXPECT_EQ(o->get_equation(), equation); + EXPECT_EQ(o->get_element_type(), et); + EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count); + EXPECT_EQ(o->get_output_partial_shape(0), PartialShape({ov::Dimension::dynamic()})); + EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr)); +} + +TEST_F(TypePropEinsumTest, rhs_dynamic_rank_ellipsis_reduced_out_ellipsis) { + const std::string equation = "a...b,b...->a"; + constexpr auto et = element::i32; + + auto input_shapes = PartialShapes{{3, 11, 7, 4}, PartialShape::dynamic()}; + const auto inputs = make_inputs(et, input_shapes); + const auto o = make_op(inputs, equation); + + EXPECT_EQ(o->get_equation(), equation); + EXPECT_EQ(o->get_element_type(), et); + EXPECT_EQ(o->get_output_size(), exp_einsum_outputs_count); + EXPECT_EQ(o->get_output_partial_shape(0), PartialShape({3})); + EXPECT_THAT(get_shape_symbols(o->get_output_partial_shape(0)), Each(nullptr)); +} + TEST_F(TypePropEinsumTest, broadcasting_same_symbol_common) { const std::string equation = "ab,ba->b"; constexpr auto et = element::i32; From 6d53d09c564b83cb03c9b21bf00ac5cf943b295b Mon Sep 17 00:00:00 2001 From: Mateusz Mikolajczyk Date: Thu, 6 Feb 2025 17:58:30 +0100 Subject: [PATCH 19/47] Implement unsqueeze_ellipses_to_same_rank function for consistent ellipsis rank handling for broadcasting Signed-off-by: Mateusz Mikolajczyk --- .../op_conversions/einsum_decomposition.cpp | 50 ++++++++++++++++- 
src/core/reference/src/op/einsum.cpp | 47 +++++++++++++++- .../tests/functional/op_reference/einsum.cpp | 53 ++++++++++++++++++- 3 files changed, 147 insertions(+), 3 deletions(-) diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp index 35cc48253b0e1e..dd3cb721b27161 100644 --- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp @@ -776,6 +776,50 @@ void extract_diagonal(ov::pass::EinsumDecomposition* einsum_decompose_ptr, input_subscripts[input_ind] = resultant_subscript; } +/// \brief Adjusts the ranks of two input tensors by unsqueezing ellipses to the same rank. +/// +/// This function ensures that the ellipses in the input subscripts of the two tensors have the same rank by unsqueezing +/// the necessary dimensions. It modifies the inputs in place. +/// +/// \param inputs A vector of input tensors. +/// \param input_subscripts A vector of input subscripts corresponding to the input tensors. +/// \param input_ind1 The index of the first input tensor in the inputs vector. +/// \param input_ind2 The index of the second input tensor in the inputs vector. +/// \param subgraph_nodes A vector of operation nodes that is included into +/// a sub-graph decomposing Einsum that is needed for copy_runtime_info +void unsqueeze_ellipses_to_same_rank(ov::OutputVector& inputs, + std::vector& input_subscripts, + size_t input_ind1, + size_t input_ind2, + ov::NodeVector& subgraph_nodes) { + constexpr char ellipsis[] = "..."; + const auto& input1 = inputs[input_ind1]; + const auto& input2 = inputs[input_ind2]; + OPENVINO_ASSERT(input1.get_partial_shape().is_static() && input2.get_partial_shape().is_static()); + auto label_to_dim_map1 = compute_label_dim_map(input1.get_partial_shape().size(), input_subscripts[input_ind1]); + auto label_to_dim_map2 = compute_label_dim_map(input2.get_partial_shape().size(), input_subscripts[input_ind2]); + if (label_to_dim_map1.find(ellipsis) != label_to_dim_map1.end() && + label_to_dim_map2.find(ellipsis) != label_to_dim_map2.end()) { + std::vector unsqueeze_axis1, unsqueeze_axis2; + const auto& ellipsis_dims1 = label_to_dim_map1[ellipsis]; + const auto& ellipsis_dims2 = label_to_dim_map2[ellipsis]; + if (ellipsis_dims2.size() > ellipsis_dims1.size()) { + for (size_t i = 0; i < ellipsis_dims2.size() - ellipsis_dims1.size(); ++i) { + unsqueeze_axis1.push_back(ellipsis_dims1[0] + i); + } + } else if (ellipsis_dims1.size() > ellipsis_dims2.size()) { + for (size_t i = 0; i < ellipsis_dims1.size() - ellipsis_dims2.size(); ++i) { + unsqueeze_axis2.push_back(ellipsis_dims2[0] + i); + } + } + ov::Output unsqueeze_output1 = unsqueeze_input(input1, unsqueeze_axis1, subgraph_nodes); + ov::Output unsqueeze_output2 = unsqueeze_input(input2, unsqueeze_axis2, subgraph_nodes); + inputs[input_ind1] = unsqueeze_output1; + inputs[input_ind2] = unsqueeze_output2; + return; + } +} + /// \brief Contract two inputs of Einsum operation according to equation. /// The result of the contraction is appended into input_nodes along with its subscript. 
/// The input nodes for these two operands are removed from input_nodes along with their input @@ -810,6 +854,9 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, const auto& input_node1 = input_nodes[input_ind1]; const auto& input_node2 = input_nodes[input_ind2]; + // unsqueeze inputs to have same rank of ellipsis for correct broadcasting + unsqueeze_ellipses_to_same_rank(input_nodes, input_subscripts, input_ind1, input_ind2, subgraph_nodes); + // extract diagonals in case repeated labels in the corresponding input subscripts extract_diagonal(einsum_decompose_ptr, input_nodes, input_subscripts, input_ind1, subgraph_nodes); extract_diagonal(einsum_decompose_ptr, input_nodes, input_subscripts, input_ind2, subgraph_nodes); @@ -882,7 +929,8 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, // unsqueeze the first operand with new dimensions in the tail // and the number of them is equal to the number of separate labels in the second // subscript - int64_t unsqueeze_dim = labels1.size(); + + int64_t unsqueeze_dim = input_node1.get_partial_shape().size(); std::vector unsqueeze_axis1; for (size_t label_ind = 0; label_ind < separate_labels_inds2.size(); ++label_ind) { unsqueeze_axis1.push_back(unsqueeze_dim++); diff --git a/src/core/reference/src/op/einsum.cpp b/src/core/reference/src/op/einsum.cpp index ec0be118fc13e1..b0ea9ab663b790 100644 --- a/src/core/reference/src/op/einsum.cpp +++ b/src/core/reference/src/op/einsum.cpp @@ -413,7 +413,7 @@ void transpose_input(ov::TensorVector& inputs, } /// \brief Broadcast input to a new shape. The MatMul operation requires the -/// same shape of both operands in the common (or batch) dimensionsy. +/// same shape of both operands in the common (or batch) dimensions. /// template void broadcast_input(ov::TensorVector& inputs, @@ -650,6 +650,48 @@ ov::Tensor reshape_input_for_matmul(const ov::Tensor& input, return output; } +/// \brief Adjusts the rank of two input tensors by unsqueezing ellipses to the same rank. +/// +/// This function takes two input tensors and their corresponding subscripts, and ensures that +/// the ellipses ("...") in the subscripts have the same rank by unsqueezing dimensions as needed. +/// It modifies the input tensors in place. +/// +/// \param inputs A vector of input tensors. +/// \param input_subscripts A vector of strings representing the subscripts for each input tensor. +/// \param input_ind1 The index of the first input tensor in the inputs vector. +/// \param input_ind2 The index of the second input tensor in the inputs vector. 
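// Reviewer note: a standalone sketch of the axis computation performed below. Given
// where "..." starts in one input and how many dimensions it covers in each input,
// the input whose ellipsis is shorter is unsqueezed until both cover the same count.
// Names are illustrative assumptions, not OpenVINO API.
#include <cstddef>
#include <cstdint>
#include <vector>

static std::vector<int64_t> ellipsis_unsqueeze_axes_sketch(int64_t ellipsis_start,
                                                           size_t own_dims,
                                                           size_t other_dims) {
    std::vector<int64_t> axes;
    for (size_t i = 0; own_dims + i < other_dims; ++i)
        axes.push_back(ellipsis_start + static_cast<int64_t>(i));
    return axes;  // empty when this input's ellipsis is already the longer one
}
// Example: "a...b" with rank 5 covers 3 dims starting at dim 1, while "...d" with
// rank 2 covers 1 dim starting at dim 0; the second input gets axes {0, 1}, after
// which both ellipses cover 3 dimensions and NumPy broadcasting lines them up.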
+template +void unsqueeze_ellipses_to_same_rank(ov::TensorVector& inputs, + std::vector& input_subscripts, + size_t input_ind1, + size_t input_ind2) { + constexpr char ellipsis[] = "..."; + const auto& input1 = inputs[input_ind1]; + const auto& input2 = inputs[input_ind2]; + auto label_to_dim_map1 = compute_label_dim_map(input1.get_shape().size(), input_subscripts[input_ind1]); + auto label_to_dim_map2 = compute_label_dim_map(input2.get_shape().size(), input_subscripts[input_ind2]); + if (label_to_dim_map1.find(ellipsis) != label_to_dim_map1.end() && + label_to_dim_map2.find(ellipsis) != label_to_dim_map2.end()) { + std::vector unsqueeze_axis1, unsqueeze_axis2; + const auto& ellipsis_dims1 = label_to_dim_map1[ellipsis]; + const auto& ellipsis_dims2 = label_to_dim_map2[ellipsis]; + if (ellipsis_dims2.size() > ellipsis_dims1.size()) { + for (size_t i = 0; i < ellipsis_dims2.size() - ellipsis_dims1.size(); ++i) { + unsqueeze_axis1.push_back(ellipsis_dims1[0] + i); + } + } else if (ellipsis_dims1.size() > ellipsis_dims2.size()) { + for (size_t i = 0; i < ellipsis_dims1.size() - ellipsis_dims2.size(); ++i) { + unsqueeze_axis2.push_back(ellipsis_dims2[0] + i); + } + } + ov::Tensor unsqueeze_output1 = unsqueeze_input(input1, unsqueeze_axis1); + ov::Tensor unsqueeze_output2 = unsqueeze_input(input2, unsqueeze_axis2); + inputs[input_ind1] = std::move(unsqueeze_output1); + inputs[input_ind2] = std::move(unsqueeze_output2); + return; + } +} + /// \brief Contract two inputs of Einsum operation according to equation. /// The result of the contraction is appended into inputs along with its /// subscript. The inputs with indices input_ind1 and input_ind2 are removed from @@ -675,6 +717,9 @@ void contract_two_inputs(ov::TensorVector& inputs, const auto& input1 = inputs[input_ind1]; const auto& input2 = inputs[input_ind2]; + // unsqueeze inputs to have same rank of ellipsis for correct broadcasting + unsqueeze_ellipses_to_same_rank(inputs, input_subscripts, input_ind1, input_ind2); + // extract diagonals in case repeated labels in the corresponding input // subscripts extract_diagonal(inputs, input_subscripts, input_ind1); diff --git a/src/plugins/template/tests/functional/op_reference/einsum.cpp b/src/plugins/template/tests/functional/op_reference/einsum.cpp index 4dd8f46a405472..853cef8f590f73 100644 --- a/src/plugins/template/tests/functional/op_reference/einsum.cpp +++ b/src/plugins/template/tests/functional/op_reference/einsum.cpp @@ -351,7 +351,58 @@ std::vector generateParams() { .inputs({{ET, {1, 3}, std::vector{1, 2, 3}}, {ET, {3}, std::vector{1, 2, 3}}}) .equation("a...j,j...->a") .expectedResult({ET, {1}, std::vector{14}}) - .testcaseName("einsum_2in_prune_inp_ellipsis_no_out_ellipsis") + .testcaseName("einsum_2in_prune_inp_ellipsis_no_out_ellipsis"), + Builder{} + .inputs({{ET, {2, 2, 1}, std::vector{1, 2, 3, 4}}, + {ET, {4, 1, 1}, std::vector{1, 2, 3, 4}}, + {ET, + {1, 1, 2, 3, 1, 3}, + std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}}}) + .equation("a...b,bcd,...dbc->c...a") + .expectedResult( + {ET, {3, 1, 1, 2, 2}, std::vector{120, 360, 780, 1560, 150, 450, 840, 1680, 180, 540, 900, 1800}}) + .testcaseName("einsum_3in_broadcast_duplicated_ellipsis"), + + Builder{} + .inputs({{ET, {2, 2, 1}, std::vector{1, 2, 3, 4}}, + {ET, {4, 1, 1}, std::vector{1, 2, 3, 4}}, + {ET, {1, 2, 3, 1, 3, 3}, std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, + 43, 
44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54}}}) + .equation("a...b,bcd,...dbcc->c...a") + .expectedResult( + {ET, {3, 1, 2, 2}, std::vector{300, 900, 2220, 4440, 420, 1260, 2460, 4920, 540, 1620, 2700, 5400}}) + .testcaseName("einsum_3in_broadcast_duplicated_ellipsis_repeated_1"), + Builder{} + .inputs({{ET, {2, 2, 1, 1, 1}, std::vector{1, 2, 3, 4}}, + {ET, {4, 1, 1, 1, 1, 1}, std::vector{1, 2, 3, 4}}, + {ET, {3, 1, 3, 3}, std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27}}}) + .equation("a...b,bcccdd,...dbcc->cb...") + .expectedResult( + {ET, {3, 4, 2, 1, 1}, std::vector{120, 180, 240, 360, 360, 540, 480, 720, 168, 252, 336, 504, + 504, 756, 672, 1008, 216, 324, 432, 648, 648, 972, 864, 1296}}) + .testcaseName("einsum_3in_broadcast_duplicated_ellipsis_repeated_2"), + Builder{} + .inputs({{ET, {2, 2, 1}, std::vector{1, 2, 3, 4}}, + {ET, {4, 1, 1}, std::vector{1, 2, 3, 4}}, + {ET, {1, 2, 3, 1, 3, 3}, std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, + 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54}}}) + .equation("a...b,bcd,...dbcc->ca") + .expectedResult({ET, {3, 2}, std::vector{2520, 5340, 2880, 6180, 3240, 7020}}) + .testcaseName("einsum_3in_broadcast_duplicated_ellipsis_repeated_3"), + Builder{} + .inputs({{ET, {2, 2, 1, 1, 1}, std::vector{1, 2, 3, 4}}, + {ET, {4, 1, 1, 1, 1, 1}, std::vector{1, 2, 3, 4}}, + {ET, {3, 1, 3, 3}, std::vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27}}}) + .equation("a...b,bcccdd,...dbcc->cb") + .expectedResult( + {ET, {3, 4}, std::vector{300, 600, 900, 1200, 420, 840, 1260, 1680, 540, 1080, 1620, 2160}}) + .testcaseName("einsum_3in_broadcast_duplicated_ellipsis_repeated_4") }; return params; 

From a73ca6fa4e2cd394c3579856678fd5702b407b12 Mon Sep 17 00:00:00 2001 From: Mateusz Mikolajczyk Date: Thu, 6 Feb 2025 19:02:27 +0100 Subject: [PATCH 20/47] Implement requested changes to increase clarity Signed-off-by: Mateusz Mikolajczyk --- .../op_conversions/einsum_decomposition.cpp | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp index dd3cb721b27161..27a345161c3bfd 100644 --- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp @@ -884,7 +884,7 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, std::vector common_labels_inds1, common_labels_inds2; std::vector separate_labels_inds1, separate_labels_inds2; std::vector reduced_labels_inds1, reduced_labels_inds2; - std::vector common_labels, sep_labels1, sep_labels2, reduced_labels; // +++++ + std::vector common_labels, sep_labels1, sep_labels2, reduced_labels; for (size_t label_ind = 0; label_ind < labels1.size(); ++label_ind) { const auto& label = labels1[label_ind]; auto iter = std::find(labels2.begin(), labels2.end(), label); @@ -1154,8 +1154,7 @@ void fix_inputs_with_0d_ellipsis(ov::OutputVector& input_nodes, bool has_ellipsis_in_input = std::find(labels.begin(), labels.end(), ellipsis) != labels.end(); has_ellipsis |= has_ellipsis_in_input; all_no_ellipsis_or_empty &= - !has_ellipsis_in_input ||
(input_nodes[i].get_partial_shape().rank().get_length() == - static_cast(labels.size() - 1)); + !has_ellipsis_in_input || (input_nodes[i].get_partial_shape().size() + 1 == labels.size()); } if (!has_ellipsis) { @@ -1175,8 +1174,7 @@ for (size_t i = 0; i < input_nodes.size(); ++i) { const auto& labels = ov::op::v7::Einsum::extract_labels(input_subscripts[i]); if (std::find(labels.begin(), labels.end(), ellipsis) != labels.end() && - input_nodes[i].get_partial_shape().rank().get_length() == - static_cast(labels.size() - 1)) { + input_nodes[i].get_partial_shape().size() + 1 == labels.size()) { input_nodes[i] = unsqueeze_input( input_nodes[i], {static_cast( std::distance(labels.begin(), std::find(labels.begin(), labels.end(), ellipsis)))}, subgraph_nodes); } } } 

From 9b1a08e44f6c9435d960a9e9a0b4f45839ddcec2 Mon Sep 17 00:00:00 2001 From: Mateusz Mikolajczyk Date: Fri, 7 Feb 2025 14:53:24 +0100 Subject: [PATCH 21/47] Fix assert in unsqueeze_ellipses_to_same_rank Signed-off-by: Mateusz Mikolajczyk --- .../src/transformations/op_conversions/einsum_decomposition.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp index 27a345161c3bfd..a6f433869ca2ee 100644 --- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp @@ -795,7 +795,7 @@ void unsqueeze_ellipses_to_same_rank(ov::OutputVector& inputs, constexpr char ellipsis[] = "..."; const auto& input1 = inputs[input_ind1]; const auto& input2 = inputs[input_ind2]; - OPENVINO_ASSERT(input1.get_partial_shape().is_static() && input2.get_partial_shape().is_static()); + OPENVINO_ASSERT(input1.get_partial_shape().rank().is_static() && input2.get_partial_shape().rank().is_static()); auto label_to_dim_map1 = compute_label_dim_map(input1.get_partial_shape().size(), input_subscripts[input_ind1]); auto label_to_dim_map2 = compute_label_dim_map(input2.get_partial_shape().size(), input_subscripts[input_ind2]); if (label_to_dim_map1.find(ellipsis) != label_to_dim_map1.end() && From ab924267e5c13553971906b08e73a2ed71e93376 Mon Sep 17 00:00:00 2001 From: Mateusz Mikolajczyk Date: Mon, 10 Feb 2025 15:00:48 +0100 Subject: [PATCH 22/47] Add einsum decomposition test cases + minor decomposition improvements Signed-off-by: Mateusz Mikolajczyk --- .../op_conversions/einsum_decomposition.cpp | 16 +- .../einsum_decomposition_test.cpp | 1011 +++++++++++++++++ 2 files changed, 1019 insertions(+), 8 deletions(-) create mode 100644 src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp index a6f433869ca2ee..c3de50ac32fb14 100644 --- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp @@ -618,17 +618,17 @@ ov::Output build_identity(const ov::Output& input_node, ov::op::v0::Constant::create(ov::element::i64, {}, {repeated_label_dims.size()}); const auto const_0 = ov::op::v0::Constant::create(ov::element::i64, {}, {0}); const auto const_1 = ov::op::v0::Constant::create(ov::element::i64, {}, {1}); - const auto repeated_dimensions = std::make_shared(input_shape,
repeated_label_indices, const_0); - const auto reduced_dimension = std::make_shared(repeated_dimensions, const_0, const_0); - const auto range_max_val = std::make_shared(reduced_dimension, repeated_label_indices_len); + const auto repeated_dimensions = std::make_shared(input_shape, repeated_label_indices, const_0); + const auto repeated_dimension = std::make_shared(repeated_dimensions, const_0, const_0); + const auto range_max_val = std::make_shared(repeated_dimension, repeated_label_indices_len); const auto step_numerator = std::make_shared(range_max_val, const_1); - const auto step_denominator = std::make_shared(reduced_dimension, const_1); + const auto step_denominator = std::make_shared(repeated_dimension, const_1); const auto step_denominator_but_not_0 = std::make_shared(step_denominator, const_1); const auto step_numerator_but_not_0 = std::make_shared(step_numerator, const_1); const auto step = std::make_shared(step_numerator_but_not_0, step_denominator_but_not_0); const auto eye_flattened_indices = std::make_shared(const_0, range_max_val, step); - const auto reduced_dimension_1d = std::make_shared(reduced_dimension, const_0); - const auto ones = std::make_shared(const_1, reduced_dimension_1d); + const auto repeated_dimension_1d = std::make_shared(repeated_dimension, const_0); + const auto ones = std::make_shared(const_1, repeated_dimension_1d); const auto reduced_size = std::make_shared(repeated_dimensions, const_0, true); const auto zeros = std::make_shared(const_0, reduced_size); const auto eye_flattened = @@ -650,7 +650,7 @@ ov::Output build_identity(const ov::Output& input_node, const_0, const_1, repeated_dimensions, - reduced_dimension, + repeated_dimension, range_max_val, step_numerator, step_denominator, @@ -658,7 +658,7 @@ ov::Output build_identity(const ov::Output& input_node, step_numerator_but_not_0, step, eye_flattened_indices, - reduced_dimension_1d, + repeated_dimension_1d, ones, reduced_size, zeros, diff --git a/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp b/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp new file mode 100644 index 00000000000000..4d718a4c96557f --- /dev/null +++ b/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp @@ -0,0 +1,1011 @@ +// // Copyright (C) 2018-2025 Intel Corporation +// // SPDX-License-Identifier: Apache-2.0 +// // + +#include "transformations/op_conversions/einsum_decomposition.hpp" + +#include + +#include "common_test_utils/ov_test_utils.hpp" +#include "openvino/opsets/opset7.hpp" +#include "openvino/pass/constant_folding.hpp" +#include "transformations/utils/gen_pattern.hpp" + +using namespace ov; + +TEST_F(TransformationTestsF, Einsum_2in_matmul) { + PartialShape data_shape_1{5, 2}; + PartialShape data_shape_2{10, 2, 25}; + { + auto data_1 = std::make_shared(element::f32, data_shape_1); + auto data_2 = std::make_shared(element::f32, data_shape_2); + auto einsum = std::make_shared(OutputVector{data_1, data_2}, "kl,mlj->mkj"); + model = std::make_shared(NodeVector{einsum}, ParameterVector{data_1, data_2}); + manager.register_pass(); + } + { + auto data_1 = std::make_shared(element::f32, data_shape_1); + auto data_2 = std::make_shared(element::f32, data_shape_2); + auto order_2 = ov::op::v0::Constant::create(element::i64, {3}, {0, 2, 1}); + auto transpose_2 = std::make_shared(data_2, order_2); + + auto broadcast_shape_constant_1 = + ov::op::v0::Constant::create(element::i64, Shape{data_shape_1.size()}, {5, 2}); + auto broadcast_shape_constant_2 = 
+ ov::op::v0::Constant::create(element::i64, Shape{data_shape_2.size()}, {10, 25, 2}); + auto broadcast_1 = std::make_shared(data_1, + broadcast_shape_constant_1, + ov::op::BroadcastType::BIDIRECTIONAL); + auto broadcast_2 = std::make_shared(transpose_2, + broadcast_shape_constant_2, + ov::op::BroadcastType::BIDIRECTIONAL); + auto shape_constant_1 = ov::op::v0::Constant::create(element::i64, Shape{2}, {5, 2}); + auto shape_constant_2 = ov::op::v0::Constant::create(element::i64, Shape{2}, {250, 2}); + auto reshape_1 = std::make_shared(broadcast_1, shape_constant_1, false); + auto reshape_2 = std::make_shared(broadcast_2, shape_constant_2, false); + auto matmul = std::make_shared(reshape_1, reshape_2, false, true); + auto shape_out = ov::op::v0::Constant::create(element::i64, {3}, {5, 10, 25}); + auto reshape_out = std::make_shared(matmul, shape_out, false); + auto order_out = ov::op::v0::Constant::create(element::i64, {3}, {1, 0, 2}); + auto transpose_out = std::make_shared(reshape_out, order_out); + + model_ref = std::make_shared(NodeVector{transpose_out}, ParameterVector{data_1, data_2}); + } +} + +TEST_F(TransformationTestsF, Einsum_2in_matmul_dynamic) { + PartialShape data_shape_1 = PartialShape::dynamic(2); + PartialShape data_shape_2 = PartialShape::dynamic(3); + { + auto data_1 = std::make_shared(element::f32, data_shape_1); + auto data_2 = std::make_shared(element::f32, data_shape_2); + auto einsum = std::make_shared(OutputVector{data_1, data_2}, "kl,mlj->mkj"); + model = std::make_shared(NodeVector{einsum}, ParameterVector{data_1, data_2}); + manager.register_pass(); + } + { + using namespace ov::gen_pattern; + auto node_2 = std::make_shared(element::f32, data_shape_1); + auto node_0 = std::make_shared(element::f32, data_shape_2); + auto ShapeOf_487 = makeOP({node_2}, {{"output_type", "i64"}}); + auto Constant_507 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {0}); + auto Constant_508 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {1}); + auto Constant_510 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {1}); + auto StridedSlice_511 = makeOP({ShapeOf_487, Constant_507, Constant_508, Constant_510}, + {{"begin_mask", {0}}, + {"end_mask", {0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto Constant_499 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {1}); + auto Constant_489 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {1}); + auto Constant_490 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {2}); + auto Constant_492 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {1}); + auto StridedSlice_493 = makeOP({ShapeOf_487, Constant_489, Constant_490, Constant_492}, + {{"begin_mask", {0}}, + {"end_mask", {0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto Broadcast_500 = makeOP({Constant_499, StridedSlice_493}, {{"mode", "numpy"}}); + auto Constant_485 = makeConst(element::i64, + ov::Shape({ + 3, + }), + {0, 2, 1}); + auto Transpose_486 = makeOP({node_0, Constant_485}); + auto ShapeOf_488 = makeOP({Transpose_486}, {{"output_type", "i64"}}); + auto Constant_494 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {2}); + auto Constant_495 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {3}); + auto Constant_497 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {1}); + auto StridedSlice_498 = makeOP({ShapeOf_488, Constant_494, Constant_495, Constant_497}, + {{"begin_mask", {0}}, + {"end_mask", {0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + 
{"ellipsis_mask", {}}}); + auto Broadcast_503 = makeOP({Broadcast_500, StridedSlice_498}, {{"mode", "bidirectional"}}); + auto ShapeOf_506 = makeOP({Broadcast_503}, {{"output_type", "i64"}}); + auto Concat_512 = makeOP({StridedSlice_511, ShapeOf_506}, {{"axis", 0}}); + auto Broadcast_513 = makeOP({node_2, Concat_512}, {{"mode", "bidirectional"}}); + auto Constant_525 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {0}); + auto ReduceProd_526 = makeOP({StridedSlice_511, Constant_525}, {{"keep_dims", true}}); + auto ReduceProd_528 = makeOP({ShapeOf_506, {0}}, {{"keep_dims", true}}); + auto Concat_529 = makeOP({ReduceProd_526, ReduceProd_528}, {{"axis", 0}}); + auto Reshape_530 = makeOP({Broadcast_513, Concat_529}, {{"special_zero", false}}); + auto Constant_516 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {0}); + auto Constant_517 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {2}); + auto Constant_519 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {1}); + auto StridedSlice_520 = makeOP({ShapeOf_488, Constant_516, Constant_517, Constant_519}, + {{"begin_mask", {0}}, + {"end_mask", {0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto Concat_521 = makeOP({StridedSlice_520, ShapeOf_506}, {{"axis", 0}}); + auto Broadcast_522 = makeOP({Transpose_486, Concat_521}, {{"mode", "bidirectional"}}); + auto Constant_569 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {0}); + auto ReduceProd_570 = makeOP({StridedSlice_520, Constant_569}, {{"keep_dims", true}}); + auto ReduceProd_572 = makeOP({ShapeOf_506, {0}}, {{"keep_dims", true}}); + auto Concat_573 = makeOP({ReduceProd_570, ReduceProd_572}, {{"axis", 0}}); + auto Reshape_574 = makeOP({Broadcast_522, Concat_573}, {{"special_zero", false}}); + auto matmul = std::make_shared(Reshape_530, Reshape_574, false, true); + auto shape_out = makeOP({StridedSlice_511, StridedSlice_520}, {{"axis", 0}}); + auto reshape_out = std::make_shared(matmul, shape_out, false); + auto order_out = ov::op::v0::Constant::create(element::i64, {3}, {1, 0, 2}); + auto transpose_out = std::make_shared(reshape_out, order_out); + + model_ref = std::make_shared(NodeVector{transpose_out}, ParameterVector{node_2, node_0}); + } +} + +TEST_F(TransformationTestsF, Einsum_2in_matmul_ellipsis_dynamic) { + PartialShape data_shape_1 = PartialShape::dynamic(2); + PartialShape data_shape_2 = PartialShape::dynamic(5); + { + auto data_1 = std::make_shared(element::f32, data_shape_1); + auto data_2 = std::make_shared(element::f32, data_shape_2); + auto einsum = std::make_shared(OutputVector{data_1, data_2}, "kl...,m...lj->mkj"); + model = std::make_shared(NodeVector{einsum}, ParameterVector{data_1, data_2}); + manager.register_pass(); + } + { + using namespace ov::gen_pattern; + auto node_2 = std::make_shared(element::f32, data_shape_1); + auto node_0 = std::make_shared(element::f32, data_shape_2); + auto Constant_1200 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {2}); + auto Unsqueeze_1201 = makeOP({node_2, Constant_1200}); + auto Constant_1202 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {2}); + auto Unsqueeze_1203 = makeOP({Unsqueeze_1201, Constant_1202}); + auto ShapeOf_1206 = makeOP({Unsqueeze_1203}, {{"output_type", "i64"}}); + auto Constant_1226 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {0}); + auto Constant_1227 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {1}); + auto Constant_1229 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {1}); + auto StridedSlice_1230 = + 
makeOP({ShapeOf_1206, Constant_1226, Constant_1227, Constant_1229}, + {{"begin_mask", {0}}, + {"end_mask", {0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto Constant_1218 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {1}); + auto Constant_1208 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {1}); + auto Constant_1209 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {4}); + auto Constant_1211 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {1}); + auto StridedSlice_1212 = + makeOP({ShapeOf_1206, Constant_1208, Constant_1209, Constant_1211}, + {{"begin_mask", {0}}, + {"end_mask", {0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto Broadcast_1219 = makeOP({Constant_1218, StridedSlice_1212}, {{"mode", "numpy"}}); + auto Constant_1204 = makeConst(element::i64, + ov::Shape({ + 5, + }), + {0, 4, 3, 1, 2}); + auto Transpose_1205 = makeOP({node_0, Constant_1204}); + auto ShapeOf_1207 = makeOP({Transpose_1205}, {{"output_type", "i64"}}); + auto Constant_1213 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {2}); + auto Constant_1214 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {5}); + auto Constant_1216 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {1}); + auto StridedSlice_1217 = + makeOP({ShapeOf_1207, Constant_1213, Constant_1214, Constant_1216}, + {{"begin_mask", {0}}, + {"end_mask", {0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto Broadcast_1222 = + makeOP({Broadcast_1219, StridedSlice_1217}, {{"mode", "bidirectional"}}); + auto ShapeOf_1225 = makeOP({Broadcast_1222}, {{"output_type", "i64"}}); + auto Concat_1231 = makeOP({StridedSlice_1230, ShapeOf_1225}, {{"axis", 0}}); + auto Broadcast_1232 = makeOP({Unsqueeze_1203, Concat_1231}, {{"mode", "bidirectional"}}); + auto Constant_1244 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {0}); + auto ReduceProd_1245 = makeOP({StridedSlice_1230, Constant_1244}, {{"keep_dims", true}}); + auto ReduceProd_1247 = makeOP({ShapeOf_1225, {0}}, {{"keep_dims", true}}); + auto Concat_1248 = makeOP({ReduceProd_1245, ReduceProd_1247}, {{"axis", 0}}); + auto Reshape_1249 = makeOP({Broadcast_1232, Concat_1248}, {{"special_zero", false}}); + auto Constant_1235 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {0}); + auto Constant_1236 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {2}); + auto Constant_1238 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {1}); + auto StridedSlice_1239 = + makeOP({ShapeOf_1207, Constant_1235, Constant_1236, Constant_1238}, + {{"begin_mask", {0}}, + {"end_mask", {0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto Concat_1240 = makeOP({StridedSlice_1239, ShapeOf_1225}, {{"axis", 0}}); + auto Broadcast_1241 = makeOP({Transpose_1205, Concat_1240}, {{"mode", "bidirectional"}}); + auto Constant_1302 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {0}); + auto ReduceProd_1303 = makeOP({StridedSlice_1239, Constant_1302}, {{"keep_dims", true}}); + auto ReduceProd_1305 = makeOP({ShapeOf_1225, {0}}, {{"keep_dims", true}}); + auto Concat_1306 = makeOP({ReduceProd_1303, ReduceProd_1305}, {{"axis", 0}}); + auto Reshape_1307 = makeOP({Broadcast_1241, Concat_1306}, {{"special_zero", false}}); + auto MatMul_1360 = + makeOP({Reshape_1249, Reshape_1307}, {{"transpose_a", false}, {"transpose_b", true}}); + auto Concat_1361 = makeOP({StridedSlice_1230, StridedSlice_1239}, {{"axis", 0}}); + auto Reshape_1362 = 
makeOP({MatMul_1360, Concat_1361}, {{"special_zero", false}}); + auto Constant_1363 = makeConst(element::i64, + ov::Shape({ + 3, + }), + {1, 0, 2}); + auto node_4 = makeOP({Reshape_1362, Constant_1363}); + model_ref = std::make_shared(NodeVector{node_4}, ParameterVector{node_2, node_0}); + } +} + +TEST_F(TransformationTestsF, Einsum_1in_repeated_labels_ellipsis_static_cf) { + Shape data_shape_1 = {1, 3, 2, 1, 3, 1}; + { + auto data_1 = std::make_shared(element::f32, data_shape_1); + auto einsum = std::make_shared(OutputVector{data_1}, "ij...iji->j...i"); + model = std::make_shared(NodeVector{einsum}, ParameterVector{data_1}); + manager.register_pass(); + manager.register_pass(); + } + { + using namespace ov::gen_pattern; + auto node_0 = std::make_shared(element::f32, data_shape_1); + auto Multiply_1382 = makeConst( + element::f32, + ov::Shape({ + 1, + 3, + 1, + 1, + 3, + 1, + }), + {1.000000f, 0.000000f, 0.000000f, 0.000000f, 1.000000f, 0.000000f, 0.000000f, 0.000000f, 1.000000f}); + auto Multiply_1383 = makeOP({node_0, Multiply_1382}, {{"auto_broadcast", "numpy"}}); + auto Constant_1384 = makeConst(element::i64, + ov::Shape({ + 3, + }), + {3, 4, 5}); + auto ReduceSum_1385 = makeOP({Multiply_1383, Constant_1384}, {{"keep_dims", false}}); + auto Constant_1386 = makeConst(element::i64, + ov::Shape({ + 3, + }), + {1, 2, 0}); + auto node_2 = makeOP({ReduceSum_1385, Constant_1386}); + model_ref = std::make_shared(NodeVector{node_2}, ParameterVector{node_0}); + } +} + +TEST_F(TransformationTestsF, Einsum_1in_repeated_labels_empty_ellipsis_dynamic) { + PartialShape data_shape_1 = PartialShape::dynamic(5); + { + auto data_1 = std::make_shared(element::f32, data_shape_1); + auto einsum = std::make_shared(OutputVector{data_1}, "ij...iji->j...i"); + model = std::make_shared(NodeVector{einsum}, ParameterVector{data_1}); + manager.register_pass(); + } + { + using namespace ov::gen_pattern; + auto node_0 = std::make_shared(element::f32, data_shape_1); + auto Constant_2112 = makeConst(element::i64, ov::Shape({}), {0}); + auto ShapeOf_2109 = makeOP({node_0}, {{"output_type", "i64"}}); + auto Constant_2110 = makeConst(element::i64, + ov::Shape({ + 3, + }), + {0, 2, 4}); + auto Gather_2114 = makeOP({ShapeOf_2109, Constant_2110, Constant_2112}, {{"batch_dims", 0}}); + auto ReduceProd_2526 = makeOP({Gather_2114, Constant_2112}, {{"keep_dims", true}}); + auto Constant_2527 = makeConst(element::u8, ov::Shape({}), {0}); + auto Broadcast_2528 = + makeOP({Constant_2112, ReduceProd_2526, Constant_2527}, {{"mode", "numpy"}}); + auto Gather_2115 = makeOP({Gather_2114, Constant_2112, Constant_2112}, {{"batch_dims", 0}}); + auto Constant_2111 = makeConst(element::i64, ov::Shape({}), {3}); + auto Power_2116 = makeOP({Gather_2115, Constant_2111}, {{"auto_broadcast", "numpy"}}); + auto Constant_2113 = makeConst(element::i64, ov::Shape({}), {1}); + auto Subtract_2117 = makeOP({Power_2116, Constant_2113}, {{"auto_broadcast", "numpy"}}); + auto Maximum_2120 = makeOP({Subtract_2117, Constant_2113}, {{"auto_broadcast", "numpy"}}); + auto Subtract_2118 = makeOP({Gather_2115, Constant_2113}, {{"auto_broadcast", "numpy"}}); + auto Maximum_2119 = makeOP({Subtract_2118, Constant_2113}, {{"auto_broadcast", "numpy"}}); + auto Divide_2121 = + makeOP({Maximum_2120, Maximum_2119}, {{"auto_broadcast", "numpy"}, {"m_pythondiv", true}}); + auto Range_2122 = makeOP({Constant_2112, Power_2116, Divide_2121}); + auto Unsqueeze_2521 = makeOP({Gather_2115, Constant_2112}); + auto Constant_2522 = makeConst(element::u8, ov::Shape({}), {0}); + auto 
Broadcast_2523 = + makeOP({Constant_2113, Unsqueeze_2521, Constant_2522}, {{"mode", "numpy"}}); + auto ScatterElementsUpdate_2557 = + makeOP({Broadcast_2528, Range_2122, Broadcast_2523, Constant_2112}); + auto ShapeOf_2558 = makeOP({ShapeOf_2109}); + auto Constant_2559 = makeConst(element::u8, ov::Shape({}), {0}); + auto Broadcast_2560 = + makeOP({Constant_2113, ShapeOf_2558, Constant_2559}, {{"mode", "numpy"}}); + auto ScatterElementsUpdate_2563 = + makeOP({Broadcast_2560, Constant_2110, Gather_2114, Constant_2112}); + auto Reshape_2564 = makeOP({ScatterElementsUpdate_2557, ScatterElementsUpdate_2563}, + {{"special_zero", false}}); + auto Convert_2565 = makeOP({Reshape_2564}, {{"destination_type", "f32"}}); + auto Constant_2569 = makeConst(element::i64, ov::Shape({}), {0}); + auto ShapeOf_2566 = makeOP({node_0}, {{"output_type", "i64"}}); + auto Constant_2567 = makeConst(element::i64, + ov::Shape({ + 2, + }), + {1, 3}); + auto Gather_2571 = makeOP({ShapeOf_2566, Constant_2567, Constant_2569}, {{"batch_dims", 0}}); + auto ReduceProd_2983 = makeOP({Gather_2571, Constant_2569}, {{"keep_dims", true}}); + auto Constant_2984 = makeConst(element::u8, ov::Shape({}), {0}); + auto Broadcast_2985 = + makeOP({Constant_2569, ReduceProd_2983, Constant_2984}, {{"mode", "numpy"}}); + auto Gather_2572 = makeOP({Gather_2571, Constant_2569, Constant_2569}, {{"batch_dims", 0}}); + auto Constant_2568 = makeConst(element::i64, ov::Shape({}), {2}); + auto Power_2573 = makeOP({Gather_2572, Constant_2568}, {{"auto_broadcast", "numpy"}}); + auto Constant_2570 = makeConst(element::i64, ov::Shape({}), {1}); + auto Subtract_2574 = makeOP({Power_2573, Constant_2570}, {{"auto_broadcast", "numpy"}}); + auto Maximum_2577 = makeOP({Subtract_2574, Constant_2570}, {{"auto_broadcast", "numpy"}}); + auto Subtract_2575 = makeOP({Gather_2572, Constant_2570}, {{"auto_broadcast", "numpy"}}); + auto Maximum_2576 = makeOP({Subtract_2575, Constant_2570}, {{"auto_broadcast", "numpy"}}); + auto Divide_2578 = + makeOP({Maximum_2577, Maximum_2576}, {{"auto_broadcast", "numpy"}, {"m_pythondiv", true}}); + auto Range_2579 = makeOP({Constant_2569, Power_2573, Divide_2578}); + auto Unsqueeze_2978 = makeOP({Gather_2572, Constant_2569}); + auto Constant_2979 = makeConst(element::u8, ov::Shape({}), {0}); + auto Broadcast_2980 = + makeOP({Constant_2570, Unsqueeze_2978, Constant_2979}, {{"mode", "numpy"}}); + auto ScatterElementsUpdate_3014 = + makeOP({Broadcast_2985, Range_2579, Broadcast_2980, Constant_2569}); + auto ShapeOf_3015 = makeOP({ShapeOf_2566}); + auto Constant_3016 = makeConst(element::u8, ov::Shape({}), {0}); + auto Broadcast_3017 = + makeOP({Constant_2570, ShapeOf_3015, Constant_3016}, {{"mode", "numpy"}}); + auto ScatterElementsUpdate_3020 = + makeOP({Broadcast_3017, Constant_2567, Gather_2571, Constant_2569}); + auto Reshape_3021 = makeOP({ScatterElementsUpdate_3014, ScatterElementsUpdate_3020}, + {{"special_zero", false}}); + auto Convert_3022 = makeOP({Reshape_3021}, {{"destination_type", "f32"}}); + auto Multiply_3023 = makeOP({Convert_2565, Convert_3022}, {{"auto_broadcast", "numpy"}}); + auto Multiply_3024 = makeOP({node_0, Multiply_3023}, {{"auto_broadcast", "numpy"}}); + auto Constant_3025 = makeConst(element::i64, + ov::Shape({ + 3, + }), + {2, 3, 4}); + auto ReduceSum_3026 = makeOP({Multiply_3024, Constant_3025}, {{"keep_dims", false}}); + auto Constant_3027 = makeConst(element::i64, + ov::Shape({ + 2, + }), + {1, 0}); + auto node_2 = makeOP({ReduceSum_3026, Constant_3027}); + model_ref = 
std::make_shared(NodeVector{node_2}, ParameterVector{node_0}); + } +} + +TEST_F(TransformationTestsF, Einsum_3in_broadcast_duplicated_ellipsis_repeated_static_cf) { + PartialShape data_shape_1 = {1, 2, 2, 1, 1, 1}; + PartialShape data_shape_2 = {4, 1, 1, 1, 1, 1}; + PartialShape data_shape_3 = {3, 1, 3, 3}; + { + auto data_1 = std::make_shared(element::f32, data_shape_1); + auto data_2 = std::make_shared(element::f32, data_shape_2); + auto data_3 = std::make_shared(element::f32, data_shape_3); + auto einsum = + std::make_shared(OutputVector{data_1, data_2, data_3}, "ba...b,bcccdd,...dbcc->c...b"); + model = std::make_shared(NodeVector{einsum}, ParameterVector{data_1, data_2, data_3}); + manager.register_pass(); + manager.register_pass(); + } + { + using namespace ov::gen_pattern; + auto node_0 = std::make_shared(element::f32, data_shape_3); + auto node_2 = std::make_shared(element::f32, data_shape_2); + auto node_4 = std::make_shared(element::f32, data_shape_1); + auto Multiply_1990 = makeConst(element::f32, + ov::Shape({ + 1, + 1, + 1, + 1, + 1, + 1, + }), + {1.000000f}); + auto Multiply_1991 = makeOP({node_2, Multiply_1990}, {{"auto_broadcast", "numpy"}}); + auto Constant_1992 = makeConst(element::i64, + ov::Shape({ + 3, + }), + {2, 3, 5}); + auto ReduceSum_1993 = makeOP({Multiply_1991, Constant_1992}, {{"keep_dims", false}}); + auto Concat_2034 = makeConst(element::i64, + ov::Shape({ + 3, + }), + {4, 3, 3}); + auto Broadcast_2035 = makeOP({ReduceSum_1993, Concat_2034}, {{"mode", "bidirectional"}}); + auto Concat_2051 = makeConst(element::i64, + ov::Shape({ + 4, + }), + {4, 3, 3, 1}); + auto Reshape_2052 = makeOP({Broadcast_2035, Concat_2051}, {{"special_zero", false}}); + auto Convert_1700 = makeConst(element::f32, + ov::Shape({ + 1, + 1, + 1, + 1, + 1, + 1, + }), + {1.000000f}); + auto Multiply_1701 = makeOP({node_4, Convert_1700}, {{"auto_broadcast", "numpy"}}); + auto Constant_1702 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {5}); + auto ReduceSum_1703 = makeOP({Multiply_1701, Constant_1702}, {{"keep_dims", false}}); + auto Constant_1799 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {1}); + auto ReduceSum_1800 = makeOP({ReduceSum_1703, Constant_1799}, {{"keep_dims", false}}); + auto Constant_1803 = makeConst(element::i64, + ov::Shape({ + 2, + }), + {4, 5}); + auto Unsqueeze_1804 = makeOP({ReduceSum_1800, Constant_1803}); + auto Constant_1605 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {0}); + auto Unsqueeze_1606 = makeOP({node_0, Constant_1605}); + auto Constant_1607 = makeConst(element::i64, + ov::Shape({ + 2, + }), + {0, 1}); + auto Unsqueeze_1608 = makeOP({Unsqueeze_1606, Constant_1607}); + auto Convert_1795 = makeConst( + element::f32, + ov::Shape({ + 1, + 1, + 1, + 1, + 1, + 3, + 3, + }), + {1.000000f, 0.000000f, 0.000000f, 0.000000f, 1.000000f, 0.000000f, 0.000000f, 0.000000f, 1.000000f}); + auto Multiply_1796 = makeOP({Unsqueeze_1608, Convert_1795}, {{"auto_broadcast", "numpy"}}); + auto Constant_1797 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {6}); + auto ReduceSum_1798 = makeOP({Multiply_1796, Constant_1797}, {{"keep_dims", false}}); + auto Constant_1801 = makeConst(element::i64, + ov::Shape({ + 6, + }), + {4, 0, 1, 2, 3, 5}); + auto Transpose_1802 = makeOP({ReduceSum_1798, Constant_1801}); + auto Multiply_1805 = makeOP({Unsqueeze_1804, Transpose_1802}, {{"auto_broadcast", "numpy"}}); + auto Constant_1994 = makeConst(element::i64, + ov::Shape({ + 6, + }), + {0, 5, 1, 2, 3, 4}); + auto Transpose_1995 = makeOP({Multiply_1805, Constant_1994}); 
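// Reviewer note: the 0/1 masks constant-folded in these expected graphs (e.g.
// {1,0,0,0,1,0,0,0,1}) come from build_identity: for a label repeated k times with
// extent d, the flattened eye has ones at range(0, d^k, step) with
// step = (d^k - 1) / (d - 1), which is exactly the Power/Subtract/Maximum/Divide/
// Range chain in the dynamic tests. A hypothetical standalone check, not OpenVINO code:
#include <algorithm>
#include <cstdint>
#include <vector>

static std::vector<int64_t> flattened_eye_sketch(int64_t d, int64_t k) {
    int64_t total = 1;
    for (int64_t i = 0; i < k; ++i)
        total *= d;  // d^k elements in the flattened diagonal block
    // the Maximum-with-1 ops in the graph guard the d == 1 corner case the same way
    const int64_t step = std::max<int64_t>(total - 1, 1) / std::max<int64_t>(d - 1, 1);
    std::vector<int64_t> eye(total, 0);
    for (int64_t idx = 0; idx < total; idx += step)
        eye[idx] = 1;
    return eye;  // d = 3, k = 2 -> {1, 0, 0, 0, 1, 0, 0, 0, 1}
}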
+ auto Concat_2043 = makeConst(element::i64, + ov::Shape({ + 6, + }), + {4, 3, 2, 1, 1, 3}); + auto Broadcast_2044 = makeOP({Transpose_1995, Concat_2043}, {{"mode", "bidirectional"}}); + auto Concat_2076 = makeConst(element::i64, + ov::Shape({ + 4, + }), + {4, 3, 2, 3}); + auto Reshape_2077 = makeOP({Broadcast_2044, Concat_2076}, {{"special_zero", false}}); + auto MatMul_2116 = + makeOP({Reshape_2052, Reshape_2077}, {{"transpose_a", true}, {"transpose_b", true}}); + auto Concat_2117 = makeConst(element::i64, + ov::Shape({ + 5, + }), + {4, 3, 2, 1, 1}); + auto Reshape_2118 = makeOP({MatMul_2116, Concat_2117}, {{"special_zero", false}}); + auto Constant_2119 = makeConst(element::i64, + ov::Shape({ + 5, + }), + {1, 2, 3, 4, 0}); + auto node_6 = makeOP({Reshape_2118, Constant_2119}); + model_ref = std::make_shared(NodeVector{node_6}, ParameterVector{node_4, node_2, node_0}); + } +} + +TEST_F(TransformationTestsF, Einsum_3in_broadcast_duplicated_ellipsis_repeated_dynamic) { + PartialShape data_shape_1 = PartialShape::dynamic(5); + PartialShape data_shape_2 = PartialShape::dynamic(6); + PartialShape data_shape_3 = PartialShape::dynamic(4); + { + auto data_1 = std::make_shared(element::f32, data_shape_1); + auto data_2 = std::make_shared(element::f32, data_shape_2); + auto data_3 = std::make_shared(element::f32, data_shape_3); + auto einsum = + std::make_shared(OutputVector{data_1, data_2, data_3}, "a...b,bcccdd,...dbcc->c...b"); + model = std::make_shared(NodeVector{einsum}, ParameterVector{data_1, data_2, data_3}); + manager.register_pass(); + } + { + using namespace ov::gen_pattern; + auto node_0 = std::make_shared(element::f32, data_shape_3); + auto node_2 = std::make_shared(element::f32, data_shape_2); + auto node_4 = std::make_shared(element::f32, data_shape_1); + auto Constant_904 = makeConst(element::i64, ov::Shape({}), {0}); + auto ShapeOf_901 = makeOP({node_2}, {{"output_type", "i64"}}); + auto Constant_902 = makeConst(element::i64, + ov::Shape({ + 3, + }), + {1, 2, 3}); + auto Gather_906 = makeOP({ShapeOf_901, Constant_902, Constant_904}, {{"batch_dims", 0}}); + auto ReduceProd_1318 = makeOP({Gather_906, Constant_904}, {{"keep_dims", true}}); + auto Constant_1319 = makeConst(element::u8, ov::Shape({}), {0}); + auto Broadcast_1320 = + makeOP({Constant_904, ReduceProd_1318, Constant_1319}, {{"mode", "numpy"}}); + auto Gather_907 = makeOP({Gather_906, Constant_904, Constant_904}, {{"batch_dims", 0}}); + auto Constant_903 = makeConst(element::i64, ov::Shape({}), {3}); + auto Power_908 = makeOP({Gather_907, Constant_903}, {{"auto_broadcast", "numpy"}}); + auto Constant_905 = makeConst(element::i64, ov::Shape({}), {1}); + auto Subtract_909 = makeOP({Power_908, Constant_905}, {{"auto_broadcast", "numpy"}}); + auto Maximum_912 = makeOP({Subtract_909, Constant_905}, {{"auto_broadcast", "numpy"}}); + auto Subtract_910 = makeOP({Gather_907, Constant_905}, {{"auto_broadcast", "numpy"}}); + auto Maximum_911 = makeOP({Subtract_910, Constant_905}, {{"auto_broadcast", "numpy"}}); + auto Divide_913 = + makeOP({Maximum_912, Maximum_911}, {{"auto_broadcast", "numpy"}, {"m_pythondiv", true}}); + auto Range_914 = makeOP({Constant_904, Power_908, Divide_913}); + auto Unsqueeze_1313 = makeOP({Gather_907, Constant_904}); + auto Constant_1314 = makeConst(element::u8, ov::Shape({}), {0}); + auto Broadcast_1315 = + makeOP({Constant_905, Unsqueeze_1313, Constant_1314}, {{"mode", "numpy"}}); + auto ScatterElementsUpdate_1349 = + makeOP({Broadcast_1320, Range_914, Broadcast_1315, Constant_904}); + auto ShapeOf_1350 
= makeOP({ShapeOf_901}); + auto Constant_1351 = makeConst(element::u8, ov::Shape({}), {0}); + auto Broadcast_1352 = + makeOP({Constant_905, ShapeOf_1350, Constant_1351}, {{"mode", "numpy"}}); + auto ScatterElementsUpdate_1355 = + makeOP({Broadcast_1352, Constant_902, Gather_906, Constant_904}); + auto Reshape_1356 = makeOP({ScatterElementsUpdate_1349, ScatterElementsUpdate_1355}, + {{"special_zero", false}}); + auto Convert_1357 = makeOP({Reshape_1356}, {{"destination_type", "f32"}}); + auto Constant_1361 = makeConst(element::i64, ov::Shape({}), {0}); + auto ShapeOf_1358 = makeOP({node_2}, {{"output_type", "i64"}}); + auto Constant_1359 = makeConst(element::i64, + ov::Shape({ + 2, + }), + {4, 5}); + auto Gather_1363 = makeOP({ShapeOf_1358, Constant_1359, Constant_1361}, {{"batch_dims", 0}}); + auto ReduceProd_1775 = makeOP({Gather_1363, Constant_1361}, {{"keep_dims", true}}); + auto Constant_1776 = makeConst(element::u8, ov::Shape({}), {0}); + auto Broadcast_1777 = + makeOP({Constant_1361, ReduceProd_1775, Constant_1776}, {{"mode", "numpy"}}); + auto Gather_1364 = makeOP({Gather_1363, Constant_1361, Constant_1361}, {{"batch_dims", 0}}); + auto Constant_1360 = makeConst(element::i64, ov::Shape({}), {2}); + auto Power_1365 = makeOP({Gather_1364, Constant_1360}, {{"auto_broadcast", "numpy"}}); + auto Constant_1362 = makeConst(element::i64, ov::Shape({}), {1}); + auto Subtract_1366 = makeOP({Power_1365, Constant_1362}, {{"auto_broadcast", "numpy"}}); + auto Maximum_1369 = makeOP({Subtract_1366, Constant_1362}, {{"auto_broadcast", "numpy"}}); + auto Subtract_1367 = makeOP({Gather_1364, Constant_1362}, {{"auto_broadcast", "numpy"}}); + auto Maximum_1368 = makeOP({Subtract_1367, Constant_1362}, {{"auto_broadcast", "numpy"}}); + auto Divide_1370 = + makeOP({Maximum_1369, Maximum_1368}, {{"auto_broadcast", "numpy"}, {"m_pythondiv", true}}); + auto Range_1371 = makeOP({Constant_1361, Power_1365, Divide_1370}); + auto Unsqueeze_1770 = makeOP({Gather_1364, Constant_1361}); + auto Constant_1771 = makeConst(element::u8, ov::Shape({}), {0}); + auto Broadcast_1772 = + makeOP({Constant_1362, Unsqueeze_1770, Constant_1771}, {{"mode", "numpy"}}); + auto ScatterElementsUpdate_1806 = + makeOP({Broadcast_1777, Range_1371, Broadcast_1772, Constant_1361}); + auto ShapeOf_1807 = makeOP({ShapeOf_1358}); + auto Constant_1808 = makeConst(element::u8, ov::Shape({}), {0}); + auto Broadcast_1809 = + makeOP({Constant_1362, ShapeOf_1807, Constant_1808}, {{"mode", "numpy"}}); + auto ScatterElementsUpdate_1812 = + makeOP({Broadcast_1809, Constant_1359, Gather_1363, Constant_1361}); + auto Reshape_1813 = makeOP({ScatterElementsUpdate_1806, ScatterElementsUpdate_1812}, + {{"special_zero", false}}); + auto Convert_1814 = makeOP({Reshape_1813}, {{"destination_type", "f32"}}); + auto Multiply_1815 = makeOP({Convert_1357, Convert_1814}, {{"auto_broadcast", "numpy"}}); + auto Multiply_1816 = makeOP({node_2, Multiply_1815}, {{"auto_broadcast", "numpy"}}); + auto Constant_1817 = makeConst(element::i64, + ov::Shape({ + 3, + }), + {2, 3, 5}); + auto ReduceSum_1818 = makeOP({Multiply_1816, Constant_1817}, {{"keep_dims", false}}); + auto Constant_1833 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {1}); + auto ShapeOf_1821 = makeOP({ReduceSum_1818}, {{"output_type", "i64"}}); + auto Constant_1823 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {0}); + auto Constant_1824 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {2}); + auto Constant_1826 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {1}); + auto 
StridedSlice_1827 = + makeOP({ShapeOf_1821, Constant_1823, Constant_1824, Constant_1826}, + {{"begin_mask", {0}}, + {"end_mask", {0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto Broadcast_1834 = makeOP({Constant_1833, StridedSlice_1827}, {{"mode", "numpy"}}); + auto Constant_894 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {0}); + auto ReduceSum_895 = makeOP({node_4, Constant_894}, {{"keep_dims", false}}); + auto Constant_898 = makeConst(element::i64, + ov::Shape({ + 2, + }), + {4, 5}); + auto Unsqueeze_899 = makeOP({ReduceSum_895, Constant_898}); + auto Constant_430 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {0}); + auto Unsqueeze_431 = makeOP({node_0, Constant_430}); + auto Constant_432 = makeConst(element::i64, + ov::Shape({ + 2, + }), + {0, 1}); + auto Unsqueeze_433 = makeOP({Unsqueeze_431, Constant_432}); + auto Constant_437 = makeConst(element::i64, ov::Shape({}), {0}); + auto ShapeOf_434 = makeOP({Unsqueeze_433}, {{"output_type", "i64"}}); + auto Constant_435 = makeConst(element::i64, + ov::Shape({ + 2, + }), + {5, 6}); + auto Gather_439 = makeOP({ShapeOf_434, Constant_435, Constant_437}, {{"batch_dims", 0}}); + auto ReduceProd_851 = makeOP({Gather_439, Constant_437}, {{"keep_dims", true}}); + auto Constant_852 = makeConst(element::u8, ov::Shape({}), {0}); + auto Broadcast_853 = + makeOP({Constant_437, ReduceProd_851, Constant_852}, {{"mode", "numpy"}}); + auto Gather_440 = makeOP({Gather_439, Constant_437, Constant_437}, {{"batch_dims", 0}}); + auto Constant_436 = makeConst(element::i64, ov::Shape({}), {2}); + auto Power_441 = makeOP({Gather_440, Constant_436}, {{"auto_broadcast", "numpy"}}); + auto Constant_438 = makeConst(element::i64, ov::Shape({}), {1}); + auto Subtract_442 = makeOP({Power_441, Constant_438}, {{"auto_broadcast", "numpy"}}); + auto Maximum_445 = makeOP({Subtract_442, Constant_438}, {{"auto_broadcast", "numpy"}}); + auto Subtract_443 = makeOP({Gather_440, Constant_438}, {{"auto_broadcast", "numpy"}}); + auto Maximum_444 = makeOP({Subtract_443, Constant_438}, {{"auto_broadcast", "numpy"}}); + auto Divide_446 = + makeOP({Maximum_445, Maximum_444}, {{"auto_broadcast", "numpy"}, {"m_pythondiv", true}}); + auto Range_447 = makeOP({Constant_437, Power_441, Divide_446}); + auto Unsqueeze_846 = makeOP({Gather_440, Constant_437}); + auto Constant_847 = makeConst(element::u8, ov::Shape({}), {0}); + auto Broadcast_848 = + makeOP({Constant_438, Unsqueeze_846, Constant_847}, {{"mode", "numpy"}}); + auto ScatterElementsUpdate_882 = + makeOP({Broadcast_853, Range_447, Broadcast_848, Constant_437}); + auto ShapeOf_883 = makeOP({ShapeOf_434}); + auto Constant_884 = makeConst(element::u8, ov::Shape({}), {0}); + auto Broadcast_885 = makeOP({Constant_438, ShapeOf_883, Constant_884}, {{"mode", "numpy"}}); + auto ScatterElementsUpdate_888 = + makeOP({Broadcast_885, Constant_435, Gather_439, Constant_437}); + auto Reshape_889 = + makeOP({ScatterElementsUpdate_882, ScatterElementsUpdate_888}, {{"special_zero", false}}); + auto Convert_890 = makeOP({Reshape_889}, {{"destination_type", "f32"}}); + auto Multiply_891 = makeOP({Unsqueeze_433, Convert_890}, {{"auto_broadcast", "numpy"}}); + auto Constant_892 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {6}); + auto ReduceSum_893 = makeOP({Multiply_891, Constant_892}, {{"keep_dims", false}}); + auto Constant_896 = makeConst(element::i64, + ov::Shape({ + 6, + }), + {0, 1, 2, 4, 3, 5}); + auto Transpose_897 = makeOP({ReduceSum_893, Constant_896}); + auto Multiply_900 = 
makeOP({Unsqueeze_899, Transpose_897}, {{"auto_broadcast", "numpy"}}); + auto Constant_1819 = makeConst(element::i64, + ov::Shape({ + 6, + }), + {3, 5, 0, 1, 2, 4}); + auto Transpose_1820 = makeOP({Multiply_900, Constant_1819}); + auto ShapeOf_1822 = makeOP({Transpose_1820}, {{"output_type", "i64"}}); + auto Constant_1828 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {0}); + auto Constant_1829 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {2}); + auto Constant_1831 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {1}); + auto StridedSlice_1832 = + makeOP({ShapeOf_1822, Constant_1828, Constant_1829, Constant_1831}, + {{"begin_mask", {0}}, + {"end_mask", {0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto Broadcast_1837 = + makeOP({Broadcast_1834, StridedSlice_1832}, {{"mode", "bidirectional"}}); + auto ShapeOf_1840 = makeOP({Broadcast_1837}, {{"output_type", "i64"}}); + auto Constant_1851 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {1}); + auto Constant_1841 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {2}); + auto Constant_1842 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {3}); + auto Constant_1844 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {1}); + auto StridedSlice_1845 = + makeOP({ShapeOf_1821, Constant_1841, Constant_1842, Constant_1844}, + {{"begin_mask", {0}}, + {"end_mask", {0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto Broadcast_1852 = makeOP({Constant_1851, StridedSlice_1845}, {{"mode", "numpy"}}); + auto Constant_1846 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {5}); + auto Constant_1847 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {6}); + auto Constant_1849 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {1}); + auto StridedSlice_1850 = + makeOP({ShapeOf_1822, Constant_1846, Constant_1847, Constant_1849}, + {{"begin_mask", {0}}, + {"end_mask", {0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto Broadcast_1855 = + makeOP({Broadcast_1852, StridedSlice_1850}, {{"mode", "bidirectional"}}); + auto ShapeOf_1858 = makeOP({Broadcast_1855}, {{"output_type", "i64"}}); + auto Concat_1859 = makeOP({ShapeOf_1840, ShapeOf_1858}, {{"axis", 0}}); + auto Broadcast_1860 = makeOP({ReduceSum_1818, Concat_1859}, {{"mode", "bidirectional"}}); + auto ReduceProd_1875 = makeOP({ShapeOf_1858, {0}}, {{"keep_dims", true}}); + auto Constant_1873 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {1}); + auto Concat_1876 = makeOP({ShapeOf_1840, ReduceProd_1875, Constant_1873}, {{"axis", 0}}); + auto Reshape_1903 = makeOP({Broadcast_1860, Concat_1876}, {{"special_zero", false}}); + auto Constant_1863 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {2}); + auto Constant_1864 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {5}); + auto Constant_1866 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {1}); + auto StridedSlice_1867 = + makeOP({ShapeOf_1822, Constant_1863, Constant_1864, Constant_1866}, + {{"begin_mask", {0}}, + {"end_mask", {0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + auto Concat_1868 = makeOP({ShapeOf_1840, StridedSlice_1867, ShapeOf_1858}, {{"axis", 0}}); + auto Broadcast_1869 = makeOP({Transpose_1820, Concat_1868}, {{"mode", "bidirectional"}}); + auto Constant_1904 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {0}); + auto ReduceProd_1905 = makeOP({StridedSlice_1867, Constant_1904}, {{"keep_dims", true}}); + auto ReduceProd_1907 = 
makeOP({ShapeOf_1858, {0}}, {{"keep_dims", true}}); + auto Concat_1908 = makeOP({ShapeOf_1840, ReduceProd_1905, ReduceProd_1907}, {{"axis", 0}}); + auto Reshape_1961 = makeOP({Broadcast_1869, Concat_1908}, {{"special_zero", false}}); + auto MatMul_1962 = + makeOP({Reshape_1903, Reshape_1961}, {{"transpose_a", true}, {"transpose_b", true}}); + auto Concat_1963 = makeOP({ShapeOf_1840, StridedSlice_1867}, {{"axis", 0}}); + auto Reshape_1964 = makeOP({MatMul_1962, Concat_1963}, {{"special_zero", false}}); + auto Constant_1965 = makeConst(element::i64, + ov::Shape({ + 5, + }), + {1, 2, 3, 4, 0}); + auto node_6 = makeOP({Reshape_1964, Constant_1965}); + model_ref = std::make_shared(NodeVector{node_6}, ParameterVector{node_0, node_2, node_4}); + } +} From cd422c2779c2d212b66ed84fb8947028aef1ddce Mon Sep 17 00:00:00 2001 From: Mateusz Mikolajczyk Date: Mon, 10 Feb 2025 17:38:33 +0100 Subject: [PATCH 23/47] Add missing docstrings for einsum decomposition Signed-off-by: Mateusz Mikolajczyk --- .../op_conversions/einsum_decomposition.cpp | 137 ++++++++++++++++-- 1 file changed, 121 insertions(+), 16 deletions(-) diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp index c3de50ac32fb14..aa21ea64106fbc 100644 --- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp @@ -171,6 +171,16 @@ void update_operands(ov::OutputVector& input_nodes, } using LabelDimMap = std::unordered_map>; +/// \brief Computes a mapping from labels to dimensions based on the input rank and subscript. +/// +/// This function processes the input subscript to extract labels and maps them to the corresponding +/// dimensions of the input tensor. The function also considers the presence of ellipsis ("...") in +/// the labels and adjusts the dimension map accordingly. +/// +/// \param input_rank The rank of the input tensor. It can be static or dynamic. +/// \param input_subscript The subscript string representing the labels of the input tensor dimensions. +/// \return A map where the keys are labels (strings) and the values are vectors of dimension indices. +/// LabelDimMap compute_label_dim_map(const ov::Rank& input_rank, const std::string& input_subscript) { static const std::string ellipsis = "..."; const auto labels = ov::op::v7::Einsum::extract_labels(input_subscript); @@ -204,6 +214,26 @@ LabelDimMap compute_label_dim_map(const ov::Rank& input_rank, const std::string& return resulted_map; } +/// \brief Computes the ranges for common, separated, and reduced labels in the input subscript. +/// +/// This function calculates the start and end indices for common, separated, and reduced labels +/// based on the input rank and subscript. It also considers the presence of ellipsis ("...") in +/// the labels and adjusts the ranks accordingly. +/// +/// \param input_rank The rank of the input tensor. +/// \param input_subscript The subscript string representing the input tensor dimensions. +/// \param common_labels A vector of strings representing the common labels. +/// \param sep_labels A vector of strings representing the separated labels. +/// \param reduced_labels A vector of strings representing the reduced labels. +/// \param common_begin Reference to a size_t variable to store the beginning index of common labels. 
+/// \param common_end Reference to a size_t variable to store the ending index of common labels.
+/// \param sep_begin Reference to a size_t variable to store the beginning index of separated labels.
+/// \param sep_end Reference to a size_t variable to store the ending index of separated labels.
+/// \param reduced_begin Reference to a size_t variable to store the beginning index of reduced labels.
+/// \param reduced_end Reference to a size_t variable to store the ending index of reduced labels.
+/// \param is_separated_first Boolean flag indicating whether the separated labels should come before the reduced
+/// labels.
+///
 void compute_ranges(const ov::Rank& input_rank,
                     const std::string& input_subscript,
                     const std::vector& common_labels,
@@ -319,6 +349,19 @@ ov::Output unsqueeze_input(const ov::Output& input_node,
     return unsqueeze->output(0);
 }
 
+/// \brief Broadcasts and merges two shapes using specified broadcasting rules.
+///
+/// This function takes two shapes (shapes_lhs and shapes_rhs) and attempts to broadcast
+/// and merge them into a single shape using NumPy and bidirectional broadcasting rules. The resulting
+/// broadcasted shape is returned as an OutputVector. If one of the input vectors is empty, the other
+/// vector is returned as is.
+///
+/// \param shapes_lhs A single element vector containing the left-hand side shape to be broadcasted or empty.
+/// \param shapes_rhs A single element vector containing the right-hand side shape to be broadcasted or empty.
+/// \param subgraph_nodes A vector to which the nodes created during the broadcasting process are added.
+/// \return An OutputVector containing the broadcasted and merged shape. If one of the input vectors is empty,
+/// the other vector is returned.
+///
 ov::OutputVector broadcast_merge_shapes(ov::OutputVector& shapes_lhs,
                                         ov::OutputVector& shapes_rhs,
                                         ov::NodeVector& subgraph_nodes) {
@@ -607,10 +650,23 @@ void reduce_input(ov::pass::EinsumDecomposition* einsum_decompose_ptr,
     subgraph_nodes.insert(subgraph_nodes.end(), {axes_const, reduce_sum});
 }
 
+/// \brief Builds an n-dimensional identity tensor based on the input node and repeated label dimensions.
+///
+/// This function constructs an identity tensor whose number of dimensions matches the number of repeats of a single
+/// label.
+///
+/// \param input_node The input node for which the identity tensor is to be built.
+/// \param repeated_label_dims A vector containing the dimensions of the repeated label.
+/// \param subgraph_nodes A vector of operation nodes that is included into
+/// a sub-graph decomposing Einsum that is needed for copy_runtime_info
+///
+/// \return The final node representing the identity tensor, reshaped to match input rank and correct dimensions
+/// with repeated labels.
 ov::Output build_identity(const ov::Output& input_node,
                           const std::vector& repeated_label_dims,
                           ov::NodeVector& subgraph_nodes) {
     OPENVINO_ASSERT(repeated_label_dims.size() > 1);
+    // Create flattened (repeated_label_dims.size())-dimensional eye tensor with 1s on the diagonal.
     const auto input_shape = std::make_shared(input_node);
     const auto repeated_label_indices =
         ov::op::v0::Constant::create(ov::element::i64, {repeated_label_dims.size()}, repeated_label_dims);
@@ -634,6 +690,7 @@ ov::Output build_identity(const ov::Output& input_node,
     const auto eye_flattened = std::make_shared(zeros, eye_flattened_indices, ones, const_0);
 
+    // Prepare target shape for identity tensor for specified repeated label dimensions.
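+    // For example (an illustrative assumption, not taken from a test): for a rank-4 input with one
+    // label repeated at dims {1, 3} of size d, the target shape computed below is [1, d, 1, d], so
+    // the flattened eye tensor ends up laid out only along the repeated-label dimensions.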
 const auto identity_rank = std::make_shared(input_shape);
 const auto ones_of_input_shape_rank = std::make_shared(const_1, identity_rank);
 const auto identity_shape = std::make_shared(ones_of_input_shape_rank,
@@ -641,6 +698,7 @@ ov::Output build_identity(const ov::Output& input_node,
                                               repeated_dimensions,
                                               const_0);
 
+    // Reshape the flattened identity tensor to the target shape.
     const auto identity = std::make_shared(eye_flattened, identity_shape, false);
     const auto identity_cvt = std::make_shared(identity, input_node.get_element_type());
     subgraph_nodes.insert(subgraph_nodes.end(),
@@ -671,8 +729,20 @@ ov::Output build_identity(const ov::Output& input_node,
     return subgraph_nodes.back();
 }
 
-ov::Output build_multi_identity(ov::pass::EinsumDecomposition* einsum_decompose_ptr,
-                                const ov::Output& input_node,
+/// \brief Builds a multi-identity node by multiplying identity nodes for each repeated label.
+///
+/// This function constructs a multi-identity node by iteratively multiplying identity nodes
+/// corresponding to each repeated label. The identity nodes are built using the provided
+/// input node and label dimension map.
+///
+/// \param input_node The input node for which the identity nodes are to be built.
+/// \param repeated_labels A vector of repeated labels for which identity nodes are to be created.
+/// \param label_dim_map A map from labels to their corresponding dimensions.
+/// \param subgraph_nodes A vector of operation nodes that is included into
+/// a sub-graph decomposing Einsum that is needed for copy_runtime_info
+/// \return The final multi-identity node after multiplying all identity nodes.
+///
+ov::Output build_multi_identity(const ov::Output& input_node,
                                 const std::vector& repeated_labels,
                                 const LabelDimMap& label_dim_map,
                                 ov::NodeVector& subgraph_nodes) {
@@ -696,8 +766,16 @@ ov::Output build_multi_identity(ov::pass::EinsumDecomposition* einsum_
     return subgraph_nodes.back();
 }
 
-/// \brief Helper function to fill in the data needed for diagonal extraction - result shape
-/// and subscript, repeated labels, axes to reduce.
+/// \brief Prepares data for diagonal extraction in Einsum operation.
+///
+/// This function processes the input subscript and label-dimension map to identify repeated labels,
+/// update the resultant subscript, and determine the axes to be reduced.
+///
+/// \param input_subscript The input subscript string representing the Einsum equation.
+/// \param label_dim_map A map from labels to their corresponding dimensions.
+/// \param resultant_subscript A reference to the resultant subscript string to be updated.
+/// \param repeated_labels A reference to a vector of strings to store repeated labels found in the input subscript.
+/// \param reduced_axes A reference to an AxisSet to store the axes that need to be reduced.
 ///
 void prepare_diagonal_extraction_data(const std::string& input_subscript,
                                       const LabelDimMap& label_dim_map,
@@ -732,8 +810,20 @@ void prepare_diagonal_extraction_data(const std::string& input_subscript,
     }
 }
 
-void extract_diagonal(ov::pass::EinsumDecomposition* einsum_decompose_ptr,
-                      ov::OutputVector& inputs,
+///
+/// \brief Extracts the diagonal elements from the input tensor based on the provided subscripts.
+///
+/// This function modifies the input tensor by extracting its diagonal elements for repeated labels and updating the
+/// corresponding subscript. The diagonal extraction is performed by multiplying the input tensor with a
+/// multi-identity tensor.
+///
+/// \param inputs A vector of input tensors.
+/// \param input_subscripts A vector of subscripts corresponding to each input tensor. +/// \param input_ind The index of the input tensor to be processed. +/// \param subgraph_nodes A vector of operation nodes that is included into +/// a sub-graph decomposing Einsum that is needed for copy_runtime_info +/// +void extract_diagonal(ov::OutputVector& inputs, std::vector& input_subscripts, size_t input_ind, ov::NodeVector& subgraph_nodes) { @@ -758,8 +848,7 @@ void extract_diagonal(ov::pass::EinsumDecomposition* einsum_decompose_ptr, if (repeated_labels.size() == 0) { return; } - const auto multi_identity = - build_multi_identity(einsum_decompose_ptr, input_node, repeated_labels, label_dim_map, subgraph_nodes); + const auto multi_identity = build_multi_identity(input_node, repeated_labels, label_dim_map, subgraph_nodes); // multiply both operands with broadcasting const auto mul = @@ -858,8 +947,8 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, unsqueeze_ellipses_to_same_rank(input_nodes, input_subscripts, input_ind1, input_ind2, subgraph_nodes); // extract diagonals in case repeated labels in the corresponding input subscripts - extract_diagonal(einsum_decompose_ptr, input_nodes, input_subscripts, input_ind1, subgraph_nodes); - extract_diagonal(einsum_decompose_ptr, input_nodes, input_subscripts, input_ind2, subgraph_nodes); + extract_diagonal(input_nodes, input_subscripts, input_ind1, subgraph_nodes); + extract_diagonal(input_nodes, input_subscripts, input_ind2, subgraph_nodes); // reduce dimensions for input operands if possible reduce_input(einsum_decompose_ptr, input_nodes, input_subscripts, output_subscript, input_ind1, subgraph_nodes); @@ -1186,12 +1275,25 @@ void fix_inputs_with_0d_ellipsis(ov::OutputVector& input_nodes, } } // namespace +/// \brief Constructor for the EinsumDecomposition transformation pass. +/// +/// This transformation decomposes the Einsum operation into a sequence of more basic operations. +/// It matches the Einsum operation and replaces it with a sub-graph of operations that perform +/// the same computation. +/// +/// The transformation follows these steps: +/// 1. Parse the Einsum equation to extract input and output subscripts. +/// 2. Check if the transformation is applicable by ensuring all input nodes have static ranks. +/// 3. Compute the optimal path for contracting pairs of operands. +/// 4. Fix inputs where ellipsis does not contain any dimensions. +/// 5. Contract inputs by Einsum until only one input remains. +/// 6. Extract the diagonal for the single remaining operand. +/// 7. Reduce dimensions for the remaining input node. +/// 8. Transpose dimensions to match the layout required by the output subscript. +/// 9. Replace the original Einsum node with the last node from the decomposed sub-graph, +/// preserving the original node's name and runtime information. ov::pass::EinsumDecomposition::EinsumDecomposition() { - // NOTE: The transformation is applicable if Einsum equation does not contain ellipsis label - // and does not contain subscripts with repeated labels. - // For example, the transformation is applicable to Einsum with equation="abc,bd->ad" - // but not applicable to a case with equation="aabc,bd->ad" due to repeated labels - // in the first input subscript. 
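+    // For example, a plain two-input contraction decomposes into a Transpose -> Broadcast ->
+    // Reshape -> MatMul -> Reshape -> Transpose chain, as exercised by the Einsum_2in_matmul
+    // tests in einsum_decomposition_test.cpp (an illustrative summary, not an exhaustive list).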
+ MATCHER_SCOPE(EinsumDecomposition); auto einsum = ov::pass::pattern::wrap_type(); matcher_pass_callback callback = [this](ov::pass::pattern::Matcher& m) { @@ -1200,6 +1302,7 @@ ov::pass::EinsumDecomposition::EinsumDecomposition() { return false; } + // Parse the Einsum equation to get input and output subscripts auto equation = einsum_node->get_equation(); std::vector input_subscripts; std::string output_subscript; @@ -1209,7 +1312,8 @@ ov::pass::EinsumDecomposition::EinsumDecomposition() { // and a vector of sub-graph nodes for copy_runtime_info ov::OutputVector input_nodes = einsum_node->input_values(); ov::NodeVector subgraph_nodes; - // check that the transformation is applicable + + // Check if the transformation is applicable by ensuring all input nodes have static ranks if (std::any_of(input_nodes.cbegin(), input_nodes.cend(), [](ov::Output node) { return node.get_partial_shape().rank().is_dynamic(); })) { @@ -1234,6 +1338,7 @@ ov::pass::EinsumDecomposition::EinsumDecomposition() { subgraph_nodes); } + // Ensure only one input node remains after contraction OPENVINO_ASSERT(input_nodes.size() == 1); // extract diagonal for the single operand From 9e3153e1cc37090fa4764968a53c5481261fb5f7 Mon Sep 17 00:00:00 2001 From: Mateusz Mikolajczyk Date: Mon, 10 Feb 2025 17:49:31 +0100 Subject: [PATCH 24/47] Fix extract diagonal call Signed-off-by: Mateusz Mikolajczyk --- .../transformations/op_conversions/einsum_decomposition.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp index aa21ea64106fbc..82722682c853dd 100644 --- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp @@ -1293,7 +1293,6 @@ void fix_inputs_with_0d_ellipsis(ov::OutputVector& input_nodes, /// 9. Replace the original Einsum node with the last node from the decomposed sub-graph, /// preserving the original node's name and runtime information. 
ov::pass::EinsumDecomposition::EinsumDecomposition() { - MATCHER_SCOPE(EinsumDecomposition); auto einsum = ov::pass::pattern::wrap_type(); matcher_pass_callback callback = [this](ov::pass::pattern::Matcher& m) { @@ -1342,7 +1341,7 @@ ov::pass::EinsumDecomposition::EinsumDecomposition() { OPENVINO_ASSERT(input_nodes.size() == 1); // extract diagonal for the single operand - extract_diagonal(this, input_nodes, input_subscripts, 0, subgraph_nodes); + extract_diagonal(input_nodes, input_subscripts, 0, subgraph_nodes); // reduce dimensions for the remained input node reduce_input(this, input_nodes, input_subscripts, output_subscript, 0, subgraph_nodes); // transpose dimensions to layout required by the output subscript From a6cbfd471a2539dd864cfdd10904afedfd3b73e2 Mon Sep 17 00:00:00 2001 From: Mateusz Mikolajczyk Date: Mon, 10 Feb 2025 18:08:59 +0100 Subject: [PATCH 25/47] Remove dependency on einsum_decompose_ptr Signed-off-by: Mateusz Mikolajczyk --- .../op_conversions/einsum_decomposition.cpp | 21 +++++++------------ 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp index 82722682c853dd..eb9a42b9d8af19 100644 --- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp @@ -585,7 +585,6 @@ void transpose_input(ov::OutputVector& input_nodes, /// \brief Find labels (in a given input subscript) that are met once in the equation /// and reduce dimensions corresponding to such labels /// -/// \param einsum_decompose_ptr A pointer to Einsum decomposing pass /// \param input_nodes A vector of input nodes to Einsum operation /// \param input_subscripts A vector of corresponding subscripts for the input nodes /// \param output_subscript The output subscript @@ -594,8 +593,7 @@ void transpose_input(ov::OutputVector& input_nodes, /// \param subgraph_nodes A vector of operation nodes that is included into /// a sub-graph decomposing Einsum that is needed for copy_runtime_info /// -void reduce_input(ov::pass::EinsumDecomposition* einsum_decompose_ptr, - ov::OutputVector& input_nodes, +void reduce_input(ov::OutputVector& input_nodes, std::vector& input_subscripts, const std::string& output_subscript, size_t input_ind, @@ -639,9 +637,7 @@ void reduce_input(ov::pass::EinsumDecomposition* einsum_decompose_ptr, const std::vector reduced_axes_vec{reduced_axes.cbegin(), reduced_axes.cend()}; const auto axes_const = ov::op::v0::Constant::create(ov::element::Type_t::i64, ov::Shape{reduced_axes.size()}, reduced_axes_vec); - const auto reduce_sum = - einsum_decompose_ptr->register_new_node(input_node, axes_const, false); - + const auto reduce_sum = std::make_shared(input_node, axes_const, false); // update a vector of inputs and input subscripts input_nodes[input_ind] = reduce_sum->output(0); input_subscripts[input_ind] = new_input_subscript; @@ -914,7 +910,6 @@ void unsqueeze_ellipses_to_same_rank(ov::OutputVector& inputs, /// The input nodes for these two operands are removed from input_nodes along with their input /// subscripts /// -/// \param einsum_decompose_ptr A pointer to Einsum decomposing pass /// \param input_nodes A vector of input nodes to Einsum operation /// \param input_subscripts A vector of corresponding subscripts for the input nodes /// \param output_subscript The output subscript @@ 
-923,8 +918,7 @@ void unsqueeze_ellipses_to_same_rank(ov::OutputVector& inputs, /// \param subgraph_nodes A vector of operation nodes that is included into a /// sub-graph decomposing Einsum that is needed for copy_runtime_info /// -void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, - ov::OutputVector& input_nodes, +void contract_two_inputs(ov::OutputVector& input_nodes, std::vector& input_subscripts, const std::string& output_subscript, size_t input_ind1, @@ -951,8 +945,8 @@ void contract_two_inputs(ov::pass::EinsumDecomposition* einsum_decompose_ptr, extract_diagonal(input_nodes, input_subscripts, input_ind2, subgraph_nodes); // reduce dimensions for input operands if possible - reduce_input(einsum_decompose_ptr, input_nodes, input_subscripts, output_subscript, input_ind1, subgraph_nodes); - reduce_input(einsum_decompose_ptr, input_nodes, input_subscripts, output_subscript, input_ind2, subgraph_nodes); + reduce_input(input_nodes, input_subscripts, output_subscript, input_ind1, subgraph_nodes); + reduce_input(input_nodes, input_subscripts, output_subscript, input_ind2, subgraph_nodes); // step 0. split dimensions of both operands into three groups: // 1. dimension indices with the same labels (in both subscripts) that are NOT reduced - @@ -1328,8 +1322,7 @@ ov::pass::EinsumDecomposition::EinsumDecomposition() { // contract inputs by Einsum until just one is remained for (auto const& inds_pair : einsum_path) { - contract_two_inputs(this, - input_nodes, + contract_two_inputs(input_nodes, input_subscripts, output_subscript, inds_pair.first, @@ -1343,7 +1336,7 @@ ov::pass::EinsumDecomposition::EinsumDecomposition() { // extract diagonal for the single operand extract_diagonal(input_nodes, input_subscripts, 0, subgraph_nodes); // reduce dimensions for the remained input node - reduce_input(this, input_nodes, input_subscripts, output_subscript, 0, subgraph_nodes); + reduce_input(input_nodes, input_subscripts, output_subscript, 0, subgraph_nodes); // transpose dimensions to layout required by the output subscript transpose_input(input_nodes, input_subscripts, output_subscript, 0, subgraph_nodes); // replace the original Einsum node with the last node from decomposing sub-graph From 0726b36ac8742e6906ff8a5bb3483cea93129297 Mon Sep 17 00:00:00 2001 From: Mateusz Mikolajczyk Date: Mon, 10 Feb 2025 18:50:28 +0100 Subject: [PATCH 26/47] Minor change to remove duplicated converts in favor of single convertlike in repeated labels Signed-off-by: Mateusz Mikolajczyk --- .../op_conversions/einsum_decomposition.cpp | 10 +++++----- .../einsum_decomposition_test.cpp | 18 ++++++++---------- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp index eb9a42b9d8af19..46fd134c18c1a2 100644 --- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp @@ -12,6 +12,7 @@ #include "openvino/op/broadcast.hpp" #include "openvino/op/concat.hpp" #include "openvino/op/constant.hpp" +#include "openvino/op/convert_like.hpp" #include "openvino/op/divide.hpp" #include "openvino/op/einsum.hpp" #include "openvino/op/gather.hpp" @@ -696,7 +697,6 @@ ov::Output build_identity(const ov::Output& input_node, // Reshape the flattened identity tensor to the target shape. 
const auto identity = std::make_shared(eye_flattened, identity_shape, false); - const auto identity_cvt = std::make_shared(identity, input_node.get_element_type()); subgraph_nodes.insert(subgraph_nodes.end(), {input_shape, repeated_label_indices, @@ -720,8 +720,7 @@ ov::Output build_identity(const ov::Output& input_node, identity_rank, ones_of_input_shape_rank, identity_shape, - identity, - identity_cvt}); + identity}); return subgraph_nodes.back(); } @@ -847,9 +846,10 @@ void extract_diagonal(ov::OutputVector& inputs, const auto multi_identity = build_multi_identity(input_node, repeated_labels, label_dim_map, subgraph_nodes); // multiply both operands with broadcasting + const auto multi_identity_converted = std::make_shared(multi_identity, input_node); const auto mul = - std::make_shared(input_node, multi_identity, ov::op::AutoBroadcastType::NUMPY); - subgraph_nodes.insert(subgraph_nodes.end(), {mul}); + std::make_shared(input_node, multi_identity_converted, ov::op::AutoBroadcastType::NUMPY); + subgraph_nodes.insert(subgraph_nodes.end(), {multi_identity_converted, mul}); const std::vector reduced_axes_vec{reduced_axes.cbegin(), reduced_axes.cend()}; const auto axes_const = diff --git a/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp b/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp index 4d718a4c96557f..4d0c855710236f 100644 --- a/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp +++ b/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp @@ -454,7 +454,6 @@ TEST_F(TransformationTestsF, Einsum_1in_repeated_labels_empty_ellipsis_dynamic) makeOP({Broadcast_2560, Constant_2110, Gather_2114, Constant_2112}); auto Reshape_2564 = makeOP({ScatterElementsUpdate_2557, ScatterElementsUpdate_2563}, {{"special_zero", false}}); - auto Convert_2565 = makeOP({Reshape_2564}, {{"destination_type", "f32"}}); auto Constant_2569 = makeConst(element::i64, ov::Shape({}), {0}); auto ShapeOf_2566 = makeOP({node_0}, {{"output_type", "i64"}}); auto Constant_2567 = makeConst(element::i64, @@ -492,9 +491,9 @@ TEST_F(TransformationTestsF, Einsum_1in_repeated_labels_empty_ellipsis_dynamic) makeOP({Broadcast_3017, Constant_2567, Gather_2571, Constant_2569}); auto Reshape_3021 = makeOP({ScatterElementsUpdate_3014, ScatterElementsUpdate_3020}, {{"special_zero", false}}); - auto Convert_3022 = makeOP({Reshape_3021}, {{"destination_type", "f32"}}); - auto Multiply_3023 = makeOP({Convert_2565, Convert_3022}, {{"auto_broadcast", "numpy"}}); - auto Multiply_3024 = makeOP({node_0, Multiply_3023}, {{"auto_broadcast", "numpy"}}); + auto Multiply_3023 = makeOP({Reshape_2564, Reshape_3021}, {{"auto_broadcast", "numpy"}}); + auto ConvertLike_3024 = makeOP({Multiply_3023, node_0}); + auto Multiply_3024 = makeOP({node_0, ConvertLike_3024}, {{"auto_broadcast", "numpy"}}); auto Constant_3025 = makeConst(element::i64, ov::Shape({ 3, @@ -717,7 +716,6 @@ TEST_F(TransformationTestsF, Einsum_3in_broadcast_duplicated_ellipsis_repeated_d makeOP({Broadcast_1352, Constant_902, Gather_906, Constant_904}); auto Reshape_1356 = makeOP({ScatterElementsUpdate_1349, ScatterElementsUpdate_1355}, {{"special_zero", false}}); - auto Convert_1357 = makeOP({Reshape_1356}, {{"destination_type", "f32"}}); auto Constant_1361 = makeConst(element::i64, ov::Shape({}), {0}); auto ShapeOf_1358 = makeOP({node_2}, {{"output_type", "i64"}}); auto Constant_1359 = makeConst(element::i64, @@ -755,9 +753,9 @@ TEST_F(TransformationTestsF, 
Einsum_3in_broadcast_duplicated_ellipsis_repeated_d
         makeOP({Broadcast_1809, Constant_1359, Gather_1363, Constant_1361});
         auto Reshape_1813 = makeOP({ScatterElementsUpdate_1806, ScatterElementsUpdate_1812},
                                    {{"special_zero", false}});
-        auto Convert_1814 = makeOP({Reshape_1813}, {{"destination_type", "f32"}});
-        auto Multiply_1815 = makeOP({Convert_1357, Convert_1814}, {{"auto_broadcast", "numpy"}});
-        auto Multiply_1816 = makeOP({node_2, Multiply_1815}, {{"auto_broadcast", "numpy"}});
+        auto Multiply_1815 = makeOP({Reshape_1356, Reshape_1813}, {{"auto_broadcast", "numpy"}});
+        auto ConvertLike_1816 = makeOP({Multiply_1815, node_2});
+        auto Multiply_1816 = makeOP({node_2, ConvertLike_1816}, {{"auto_broadcast", "numpy"}});
         auto Constant_1817 = makeConst(element::i64,
                                        ov::Shape({
                                            3,
@@ -853,8 +851,8 @@ TEST_F(TransformationTestsF, Einsum_3in_broadcast_duplicated_ellipsis_repeated_d
             makeOP({Broadcast_885, Constant_435, Gather_439, Constant_437});
         auto Reshape_889 =
             makeOP({ScatterElementsUpdate_882, ScatterElementsUpdate_888}, {{"special_zero", false}});
-        auto Convert_890 = makeOP({Reshape_889}, {{"destination_type", "f32"}});
-        auto Multiply_891 = makeOP({Unsqueeze_433, Convert_890}, {{"auto_broadcast", "numpy"}});
+        auto ConvertLike_890 = makeOP({Reshape_889, Unsqueeze_433});
+        auto Multiply_891 = makeOP({Unsqueeze_433, ConvertLike_890}, {{"auto_broadcast", "numpy"}});
         auto Constant_892 = makeConst(element::i64,
                                       ov::Shape({
                                           1,

From 1d0276dc90e062ae49efce2908c8e88b3788a554 Mon Sep 17 00:00:00 2001
From: Mateusz Mikolajczyk
Date: Tue, 11 Feb 2025 10:05:13 +0100
Subject: [PATCH 27/47] Fix callback lambda

Signed-off-by: Mateusz Mikolajczyk
---
 .../src/transformations/op_conversions/einsum_decomposition.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp
index 46fd134c18c1a2..05da3e0bd3dd15 100644
--- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp
+++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp
@@ -1289,7 +1289,7 @@ void fix_inputs_with_0d_ellipsis(ov::OutputVector& input_nodes,
 ov::pass::EinsumDecomposition::EinsumDecomposition() {
     MATCHER_SCOPE(EinsumDecomposition);
     auto einsum = ov::pass::pattern::wrap_type();
-    matcher_pass_callback callback = [this](ov::pass::pattern::Matcher& m) {
+    matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
         auto einsum_node = ov::as_type_ptr(m.get_match_root());
         if (!einsum_node) {
             return false;
         }

From e3867a28eaaec7c3ac87e4901ff829f42aa2b8a2 Mon Sep 17 00:00:00 2001
From: Mateusz Mikolajczyk
Date: Tue, 11 Feb 2025 17:55:18 +0100
Subject: [PATCH 28/47] Improve readability for first two einsum decomposition
 test cases

Signed-off-by: Mateusz Mikolajczyk
---
 .../einsum_decomposition_test.cpp             | 191 +++++++++++-------
 1 file changed, 119 insertions(+), 72 deletions(-)

diff --git a/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp b/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp
index 4d0c855710236f..05cf0eebc4bfa8 100644
--- a/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp
+++ b/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp
@@ -15,7 +15,7 @@ using namespace ov;

 TEST_F(TransformationTestsF, Einsum_2in_matmul) {
     PartialShape data_shape_1{5, 2};
-    PartialShape
 data_shape_2{10, 2, 25};
+    PartialShape data_shape_2{10, 1, 25};
     {
         auto data_1 = std::make_shared(element::f32, data_shape_1);
         auto data_2 = std::make_shared(element::f32, data_shape_2);
@@ -26,9 +26,13 @@ TEST_F(TransformationTestsF, Einsum_2in_matmul) {
     {
         auto data_1 = std::make_shared(element::f32, data_shape_1);
         auto data_2 = std::make_shared(element::f32, data_shape_2);
+
+        // Transpose data_2 so that common labels, separated and reduced labels are grouped for both operands.
         auto order_2 = ov::op::v0::Constant::create(element::i64, {3}, {0, 2, 1});
         auto transpose_2 = std::make_shared(data_2, order_2);
+        // Broadcast data_1 and data_2 to common broadcasted shapes for common and reduced subshapes.
+        // Subgraphs are constant-folded, target subshapes are calculated by the broadcast_merge_shapes function.
         auto broadcast_shape_constant_1 =
             ov::op::v0::Constant::create(element::i64, Shape{data_shape_1.size()}, {5, 2});
         auto broadcast_shape_constant_2 =
@@ -39,13 +43,19 @@ TEST_F(TransformationTestsF, Einsum_2in_matmul) {
         auto broadcast_2 = std::make_shared(transpose_2,
                                             broadcast_shape_constant_2,
                                             ov::op::BroadcastType::BIDIRECTIONAL);
+        // Optionally reshape broadcasted data_1 and data_2 so separate and reduced labels are represented by one
+        // dimension. Subgraphs are constant-folded, target subshapes are calculated by the broadcast_merge_shapes function.
         auto shape_constant_1 = ov::op::v0::Constant::create(element::i64, Shape{2}, {5, 2});
         auto shape_constant_2 = ov::op::v0::Constant::create(element::i64, Shape{2}, {250, 2});
         auto reshape_1 = std::make_shared(broadcast_1, shape_constant_1, false);
         auto reshape_2 = std::make_shared(broadcast_2, shape_constant_2, false);
+        // Apply MatMul operation for formatted inputs.
         auto matmul = std::make_shared(reshape_1, reshape_2, false, true);
+        // Optionally reshape back by unrolling dimensions corresponding to separate labels if needed.
+        // Subgraphs are constant-folded, target subshapes are calculated by the broadcast_merge_shapes function.
         auto shape_out = ov::op::v0::Constant::create(element::i64, {3}, {5, 10, 25});
         auto reshape_out = std::make_shared(matmul, shape_out, false);
+        // Transpose to the original order of output labels.
         auto order_out = ov::op::v0::Constant::create(element::i64, {3}, {1, 0, 2});
         auto transpose_out = std::make_shared(reshape_out, order_out);
@@ -65,35 +75,20 @@ TEST_F(TransformationTestsF, Einsum_2in_matmul_dynamic) {
     }
     {
         using namespace ov::gen_pattern;
-        auto node_2 = std::make_shared(element::f32, data_shape_1);
-        auto node_0 = std::make_shared(element::f32, data_shape_2);
-        auto ShapeOf_487 = makeOP({node_2}, {{"output_type", "i64"}});
-        auto Constant_507 = makeConst(element::i64,
-                                      ov::Shape({
-                                          1,
-                                      }),
-                                      {0});
-        auto Constant_508 = makeConst(element::i64,
-                                      ov::Shape({
-                                          1,
-                                      }),
-                                      {1});
-        auto Constant_510 = makeConst(element::i64,
-                                      ov::Shape({
-                                          1,
-                                      }),
-                                      {1});
-        auto StridedSlice_511 = makeOP({ShapeOf_487, Constant_507, Constant_508, Constant_510},
-                                       {{"begin_mask", {0}},
-                                        {"end_mask", {0}},
-                                        {"new_axis_mask", {}},
-                                        {"shrink_axis_mask", {}},
-                                        {"ellipsis_mask", {}}});
-        auto Constant_499 = makeConst(element::i64,
+        auto data_1 = std::make_shared(element::f32, data_shape_1);
+        auto data_2 = std::make_shared(element::f32, data_shape_2);
+        // Transpose data_2 so that common labels, separated and reduced labels are grouped for both operands.
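+        // (For this test the grouping permutation happens to be {0, 2, 1}; the exact order depends
+        //  on the label layout of the equation and is an implementation detail of the decomposition.)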
+ auto Constant_485 = makeConst(element::i64, ov::Shape({ - 1, + 3, }), - {1}); + {0, 2, 1}); + auto Transpose_486 = makeOP({data_2, Constant_485}); + // Get shapes of data_1 and data_2. + auto ShapeOf_data_1 = makeOP({data_1}, {{"output_type", "i64"}}); + auto ShapeOf_data_2 = makeOP({Transpose_486}, {{"output_type", "i64"}}); + + // Get reduced subshape for data_1. auto Constant_489 = makeConst(element::i64, ov::Shape({ 1, @@ -109,20 +104,13 @@ TEST_F(TransformationTestsF, Einsum_2in_matmul_dynamic) { 1, }), {1}); - auto StridedSlice_493 = makeOP({ShapeOf_487, Constant_489, Constant_490, Constant_492}, - {{"begin_mask", {0}}, - {"end_mask", {0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto Broadcast_500 = makeOP({Constant_499, StridedSlice_493}, {{"mode", "numpy"}}); - auto Constant_485 = makeConst(element::i64, - ov::Shape({ - 3, - }), - {0, 2, 1}); - auto Transpose_486 = makeOP({node_0, Constant_485}); - auto ShapeOf_488 = makeOP({Transpose_486}, {{"output_type", "i64"}}); + auto reduced1 = makeOP({ShapeOf_data_1, Constant_489, Constant_490, Constant_492}, + {{"begin_mask", {0}}, + {"end_mask", {0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + // Get reduced subshape for data_2. auto Constant_494 = makeConst(element::i64, ov::Shape({ 1, @@ -138,25 +126,49 @@ TEST_F(TransformationTestsF, Einsum_2in_matmul_dynamic) { 1, }), {1}); - auto StridedSlice_498 = makeOP({ShapeOf_488, Constant_494, Constant_495, Constant_497}, - {{"begin_mask", {0}}, - {"end_mask", {0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto Broadcast_503 = makeOP({Broadcast_500, StridedSlice_498}, {{"mode", "bidirectional"}}); - auto ShapeOf_506 = makeOP({Broadcast_503}, {{"output_type", "i64"}}); - auto Concat_512 = makeOP({StridedSlice_511, ShapeOf_506}, {{"axis", 0}}); - auto Broadcast_513 = makeOP({node_2, Concat_512}, {{"mode", "bidirectional"}}); - auto Constant_525 = makeConst(element::i64, + auto reduced_2 = makeOP({ShapeOf_data_2, Constant_494, Constant_495, Constant_497}, + {{"begin_mask", {0}}, + {"end_mask", {0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + + // broadcast_merge_shapes(reduced1, reduced_2) + auto Constant_499 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {1}); + auto Broadcast_500 = makeOP({Constant_499, reduced1}, {{"mode", "numpy"}}); + auto Broadcast_503 = makeOP({Broadcast_500, reduced_2}, {{"mode", "bidirectional"}}); + auto reduced_subshape_broadcast_merge_shapes = + makeOP({Broadcast_503}, {{"output_type", "i64"}}); + + // Extract separate subshape for data_1. + auto Constant_507 = makeConst(element::i64, ov::Shape({ 1, }), {0}); - auto ReduceProd_526 = makeOP({StridedSlice_511, Constant_525}, {{"keep_dims", true}}); - auto ReduceProd_528 = makeOP({ShapeOf_506, {0}}, {{"keep_dims", true}}); - auto Concat_529 = makeOP({ReduceProd_526, ReduceProd_528}, {{"axis", 0}}); - auto Reshape_530 = makeOP({Broadcast_513, Concat_529}, {{"special_zero", false}}); + auto Constant_508 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {1}); + auto Constant_510 = makeConst(element::i64, + ov::Shape({ + 1, + }), + {1}); + auto separate1_subshape = + makeOP({ShapeOf_data_1, Constant_507, Constant_508, Constant_510}, + {{"begin_mask", {0}}, + {"end_mask", {0}}, + {"new_axis_mask", {}}, + {"shrink_axis_mask", {}}, + {"ellipsis_mask", {}}}); + + // Extract separate subshape for data_2. 
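+        // (Separate labels are those that appear in only one input and in the output; their
+        //  dimensions stay uncontracted and form the free subshape of each MatMul operand.)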
 auto Constant_516 = makeConst(element::i64,
                               ov::Shape({
                                   1,
                               }),
                               {0});
         auto Constant_517 = makeConst(element::i64,
                                       ov::Shape({
                                           1,
                                       }),
                                       {1});
         auto Constant_519 = makeConst(element::i64,
                                       ov::Shape({
                                           1,
                                       }),
                                       {1});
-        auto StridedSlice_520 = makeOP({ShapeOf_488, Constant_516, Constant_517, Constant_519},
-                                       {{"begin_mask", {0}},
-                                        {"end_mask", {0}},
-                                        {"new_axis_mask", {}},
-                                        {"shrink_axis_mask", {}},
-                                        {"ellipsis_mask", {}}});
-        auto Concat_521 = makeOP({StridedSlice_520, ShapeOf_506}, {{"axis", 0}});
-        auto Broadcast_522 = makeOP({Transpose_486, Concat_521}, {{"mode", "bidirectional"}});
+        auto separate2_subshape =
+            makeOP({ShapeOf_data_2, Constant_516, Constant_517, Constant_519},
+                   {{"begin_mask", {0}},
+                    {"end_mask", {0}},
+                    {"new_axis_mask", {}},
+                    {"shrink_axis_mask", {}},
+                    {"ellipsis_mask", {}}});
+
+        // Broadcast data_1 and data_2 based on calculated subshapes.
+        auto Concat_512 =
+            makeOP({separate1_subshape, reduced_subshape_broadcast_merge_shapes}, {{"axis", 0}});
+        auto Broadcast_data_1 = makeOP({data_1, Concat_512}, {{"mode", "bidirectional"}});
+        auto Concat_521 =
+            makeOP({separate2_subshape, reduced_subshape_broadcast_merge_shapes}, {{"axis", 0}});
+        auto Broadcast_data_2 = makeOP({Transpose_486, Concat_521}, {{"mode", "bidirectional"}});
+
+        // Optionally reshape broadcasted data_1 and data_2 so separate and reduced labels are represented by one
+        // dimension. Subgraphs are constant-folded, target subshapes are calculated by the broadcast_merge_shapes function.
+        // Reshape 1
+        auto Constant_525 = makeConst(element::i64,
+                                      ov::Shape({
+                                          1,
+                                      }),
+                                      {0});
+        // Reduce separate and reduced
+        auto Separate1_subshape_red =
+            makeOP({separate1_subshape, Constant_525}, {{"keep_dims", true}});
+        auto Reduced1_subshape_red =
+            makeOP({reduced_subshape_broadcast_merge_shapes, {0}}, {{"keep_dims", true}});
+        // Merge subshapes
+        auto reshape_subshape1 = makeOP({Separate1_subshape_red, Reduced1_subshape_red}, {{"axis", 0}});
+        auto Reshape_1 = makeOP({Broadcast_data_1, reshape_subshape1}, {{"special_zero", false}});
+        // Reshape 2
         auto Constant_569 = makeConst(element::i64,
                                       ov::Shape({
                                           1,
                                       }),
                                       {0});
-        auto ReduceProd_570 = makeOP({StridedSlice_520, Constant_569}, {{"keep_dims", true}});
-        auto ReduceProd_572 = makeOP({ShapeOf_506, {0}}, {{"keep_dims", true}});
-        auto Concat_573 = makeOP({ReduceProd_570, ReduceProd_572}, {{"axis", 0}});
-        auto Reshape_574 = makeOP({Broadcast_522, Concat_573}, {{"special_zero", false}});
-        auto matmul = std::make_shared(Reshape_530, Reshape_574, false, true);
-        auto shape_out = makeOP({StridedSlice_511, StridedSlice_520}, {{"axis", 0}});
+        // Reduce separate and reduced
+        auto Separate2_subshape_red =
+            makeOP({separate2_subshape, Constant_569}, {{"keep_dims", true}});
+        auto Reduced2_subshape_red =
+            makeOP({reduced_subshape_broadcast_merge_shapes, {0}}, {{"keep_dims", true}});
+        // Merge subshapes
+        auto reshape_subshape2 = makeOP({Separate2_subshape_red, Reduced2_subshape_red}, {{"axis", 0}});
+        auto Reshape_2 = makeOP({Broadcast_data_2, reshape_subshape2}, {{"special_zero", false}});
+
+        // Apply MatMul operation for formatted inputs.
+        auto matmul = std::make_shared(Reshape_1, Reshape_2, false, true);
+
+        // Optionally reshape back by unrolling dimensions corresponding to separate labels if needed.
+        // Target subshapes are calculated by the broadcast_merge_shapes function and concatenated.
+        auto shape_out = makeOP({separate1_subshape, separate2_subshape}, {{"axis", 0}});
         auto reshape_out = std::make_shared(matmul, shape_out, false);
+        // Transpose to the original order of output labels.
 auto order_out = ov::op::v0::Constant::create(element::i64, {3}, {1, 0, 2});
         auto transpose_out = std::make_shared(reshape_out, order_out);
-        model_ref = std::make_shared(NodeVector{transpose_out}, ParameterVector{node_2, node_0});
+        model_ref = std::make_shared(NodeVector{transpose_out}, ParameterVector{data_1, data_2});
     }
 }

From c75fa4730e2840173eff57729b82e226ff646e8a Mon Sep 17 00:00:00 2001
From: Mateusz Mikolajczyk
Date: Wed, 12 Feb 2025 13:12:06 +0100
Subject: [PATCH 29/47] Improve readability for einsum decomposition test 3

Signed-off-by: Mateusz Mikolajczyk
---
 .../einsum_decomposition_test.cpp             | 157 +++++++++++-------
 1 file changed, 95 insertions(+), 62 deletions(-)

diff --git a/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp b/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp
index 05cf0eebc4bfa8..5e0f1acf1cf74a 100644
--- a/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp
+++ b/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp
@@ -258,48 +258,37 @@ TEST_F(TransformationTestsF, Einsum_2in_matmul_ellipsis_dynamic) {
     }
     {
         using namespace ov::gen_pattern;
-        auto node_2 = std::make_shared(element::f32, data_shape_1);
-        auto node_0 = std::make_shared(element::f32, data_shape_2);
+        auto data_1 = std::make_shared(element::f32, data_shape_1);
+        auto data_2 = std::make_shared(element::f32, data_shape_2);
+        // Process data_1
+        // data_1 contains no dimensions at the ellipsis label; unsqueeze to allow for broadcasting.
         auto Constant_1200 = makeConst(element::i64,
                                        ov::Shape({
                                            1,
                                        }),
                                        {2});
-        auto Unsqueeze_1201 = makeOP({node_2, Constant_1200});
+        auto Unsqueeze_1201 = makeOP({data_1, Constant_1200});
+        // Match ranks of dimensions covered by ellipsis labels
         auto Constant_1202 = makeConst(element::i64,
                                        ov::Shape({
                                            1,
                                        }),
                                        {2});
-        auto Unsqueeze_1203 = makeOP({Unsqueeze_1201, Constant_1202});
-        auto ShapeOf_1206 = makeOP({Unsqueeze_1203}, {{"output_type", "i64"}});
-        auto Constant_1226 = makeConst(element::i64,
-                                       ov::Shape({
-                                           1,
-                                       }),
-                                       {0});
-        auto Constant_1227 = makeConst(element::i64,
-                                       ov::Shape({
-                                           1,
-                                       }),
-                                       {1});
-        auto Constant_1229 = makeConst(element::i64,
-                                       ov::Shape({
-                                           1,
-                                       }),
-                                       {1});
-        auto StridedSlice_1230 =
-            makeOP({ShapeOf_1206, Constant_1226, Constant_1227, Constant_1229},
-                   {{"begin_mask", {0}},
-                    {"end_mask", {0}},
-                    {"new_axis_mask", {}},
-                    {"shrink_axis_mask", {}},
-                    {"ellipsis_mask", {}}});
-        auto Constant_1218 = makeConst(element::i64,
+        auto data_1_processed = makeOP({Unsqueeze_1201, Constant_1202});
+        // Process data_2
+        // Transpose data_2 so that common labels, separated and reduced labels are grouped for both operands.
+        auto Constant_1204 = makeConst(element::i64,
                                        ov::Shape({
-                                           1,
+                                           5,
                                        }),
-                                       {1});
+                                       {0, 4, 3, 1, 2});
+        auto data_2_processed = makeOP({data_2, Constant_1204});
+
+        // Get shapes for data_1 and data_2
+        auto ShapeOf_data_1 = makeOP({data_1_processed}, {{"output_type", "i64"}});
+        auto ShapeOf_data_2 = makeOP({data_2_processed}, {{"output_type", "i64"}});
+
+        // Get reduced subshape for data_1.
        auto Constant_1208 = makeConst(element::i64,
                                        ov::Shape({
                                            1,
@@ -316,20 +305,13 @@ TEST_F(TransformationTestsF, Einsum_2in_matmul_ellipsis_dynamic) {
                                        }),
                                        {1});
         auto StridedSlice_1212 =
-            makeOP({ShapeOf_1206, Constant_1208, Constant_1209, Constant_1211},
+            makeOP({ShapeOf_data_1, Constant_1208, Constant_1209, Constant_1211},
                    {{"begin_mask", {0}},
                     {"end_mask", {0}},
                     {"new_axis_mask", {}},
                     {"shrink_axis_mask", {}},
                     {"ellipsis_mask", {}}});
-        auto Broadcast_1219 = makeOP({Constant_1218, StridedSlice_1212}, {{"mode", "numpy"}});
-        auto Constant_1204 = makeConst(element::i64,
-                                       ov::Shape({
-                                           5,
-                                       }),
-                                       {0, 4, 3, 1, 2});
-        auto Transpose_1205 = makeOP({node_0, Constant_1204});
-        auto ShapeOf_1207 = makeOP({Transpose_1205}, {{"output_type", "i64"}});
+        // Get reduced subshape for data_2.
         auto Constant_1213 = makeConst(element::i64,
                                        ov::Shape({
                                            1,
                                        }),
                                        {2});
@@ -346,26 +328,49 @@ TEST_F(TransformationTestsF, Einsum_2in_matmul_ellipsis_dynamic) {
                                        }),
                                        {1});
         auto StridedSlice_1217 =
-            makeOP({ShapeOf_1207, Constant_1213, Constant_1214, Constant_1216},
+            makeOP({ShapeOf_data_2, Constant_1213, Constant_1214, Constant_1216},
                    {{"begin_mask", {0}},
                     {"end_mask", {0}},
                     {"new_axis_mask", {}},
                     {"shrink_axis_mask", {}},
                     {"ellipsis_mask", {}}});
+        // broadcast_merge_shapes(reduced1, reduced_2)
+        auto Constant_1218 = makeConst(element::i64,
+                                       ov::Shape({
+                                           1,
+                                       }),
+                                       {1});
+        auto Broadcast_1219 = makeOP({Constant_1218, StridedSlice_1212}, {{"mode", "numpy"}});
         auto Broadcast_1222 =
             makeOP({Broadcast_1219, StridedSlice_1217}, {{"mode", "bidirectional"}});
-        auto ShapeOf_1225 = makeOP({Broadcast_1222}, {{"output_type", "i64"}});
-        auto Concat_1231 = makeOP({StridedSlice_1230, ShapeOf_1225}, {{"axis", 0}});
-        auto Broadcast_1232 = makeOP({Unsqueeze_1203, Concat_1231}, {{"mode", "bidirectional"}});
-        auto Constant_1244 = makeConst(element::i64,
+        auto reduced_subshape_broadcast_merge_shapes =
+            makeOP({Broadcast_1222}, {{"output_type", "i64"}});
+
+        // Extract separate subshape for data_1.
+        auto Constant_1226 = makeConst(element::i64,
                                        ov::Shape({
                                            1,
                                        }),
                                        {0});
-        auto ReduceProd_1245 = makeOP({StridedSlice_1230, Constant_1244}, {{"keep_dims", true}});
-        auto ReduceProd_1247 = makeOP({ShapeOf_1225, {0}}, {{"keep_dims", true}});
-        auto Concat_1248 = makeOP({ReduceProd_1245, ReduceProd_1247}, {{"axis", 0}});
-        auto Reshape_1249 = makeOP({Broadcast_1232, Concat_1248}, {{"special_zero", false}});
+        auto Constant_1227 = makeConst(element::i64,
+                                       ov::Shape({
+                                           1,
+                                       }),
+                                       {1});
+        auto Constant_1229 = makeConst(element::i64,
+                                       ov::Shape({
+                                           1,
+                                       }),
+                                       {1});
+        auto separate1_subshape =
+            makeOP({ShapeOf_data_1, Constant_1226, Constant_1227, Constant_1229},
+                   {{"begin_mask", {0}},
+                    {"end_mask", {0}},
+                    {"new_axis_mask", {}},
+                    {"shrink_axis_mask", {}},
+                    {"ellipsis_mask", {}}});
+
+        // Extract separate subshape for data_2.
         auto Constant_1235 = makeConst(element::i64,
                                        ov::Shape({
                                            1,
                                        }),
                                        {0});
@@ -381,35 +386,63 @@ TEST_F(TransformationTestsF, Einsum_2in_matmul_ellipsis_dynamic) {
                                            1,
                                        }),
                                        {1});
-        auto StridedSlice_1239 =
-            makeOP({ShapeOf_1207, Constant_1235, Constant_1236, Constant_1238},
+        auto separate2_subshape =
+            makeOP({ShapeOf_data_2, Constant_1235, Constant_1236, Constant_1238},
                    {{"begin_mask", {0}},
                     {"end_mask", {0}},
                     {"new_axis_mask", {}},
                     {"shrink_axis_mask", {}},
                     {"ellipsis_mask", {}}});
-        auto Concat_1240 = makeOP({StridedSlice_1239, ShapeOf_1225}, {{"axis", 0}});
-        auto Broadcast_1241 = makeOP({Transpose_1205, Concat_1240}, {{"mode", "bidirectional"}});
+
+        // Broadcast data_1 and data_2 based on calculated subshapes.
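+        // (Each operand is broadcast to its separate subshape concatenated with the merged
+        // reduced subshape computed above.)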
+        auto Concat_1231 =
+            makeOP({separate1_subshape, reduced_subshape_broadcast_merge_shapes}, {{"axis", 0}});
+        auto Broadcast_data_1 = makeOP({data_1_processed, Concat_1231}, {{"mode", "bidirectional"}});
+        auto Concat_1240 =
+            makeOP({separate2_subshape, reduced_subshape_broadcast_merge_shapes}, {{"axis", 0}});
+        auto Broadcast_data_2 = makeOP({data_2_processed, Concat_1240}, {{"mode", "bidirectional"}});
+
+        // Optionally reshape broadcasted data_1 and data_2 so separate and reduced labels are represented by one
+        // dimension. Subgraphs are constant-folded; target subshapes are calculated by the broadcast_merge_shapes function.
+        // Reshape 1
+        auto Constant_1244 = makeConst(element::i64,
+                                       ov::Shape({
+                                           1,
+                                       }),
+                                       {0});
+        auto Separate1_subshape_red =
+            makeOP({separate1_subshape, Constant_1244}, {{"keep_dims", true}});
+        auto Reduced1_subshape_red =
+            makeOP({reduced_subshape_broadcast_merge_shapes, {0}}, {{"keep_dims", true}});
+        auto reshape1_shape = makeOP({Separate1_subshape_red, Reduced1_subshape_red}, {{"axis", 0}});
+        auto Reshape_1 = makeOP({Broadcast_data_1, reshape1_shape}, {{"special_zero", false}});
+
+        // Reshape 2
         auto Constant_1302 = makeConst(element::i64,
                                        ov::Shape({
                                            1,
                                        }),
                                        {0});
-        auto ReduceProd_1303 = makeOP({StridedSlice_1239, Constant_1302}, {{"keep_dims", true}});
-        auto ReduceProd_1305 = makeOP({ShapeOf_1225, {0}}, {{"keep_dims", true}});
-        auto Concat_1306 = makeOP({ReduceProd_1303, ReduceProd_1305}, {{"axis", 0}});
-        auto Reshape_1307 = makeOP({Broadcast_1241, Concat_1306}, {{"special_zero", false}});
-        auto MatMul_1360 =
-            makeOP({Reshape_1249, Reshape_1307}, {{"transpose_a", false}, {"transpose_b", true}});
-        auto Concat_1361 = makeOP({StridedSlice_1230, StridedSlice_1239}, {{"axis", 0}});
-        auto Reshape_1362 = makeOP({MatMul_1360, Concat_1361}, {{"special_zero", false}});
+        auto Separate2_subshape_red =
+            makeOP({separate2_subshape, Constant_1302}, {{"keep_dims", true}});
+        auto Reduced2_subshape_red =
+            makeOP({reduced_subshape_broadcast_merge_shapes, {0}}, {{"keep_dims", true}});
+        auto reshape2_shape = makeOP({Separate2_subshape_red, Reduced2_subshape_red}, {{"axis", 0}});
+        auto Reshape_2 = makeOP({Broadcast_data_2, reshape2_shape}, {{"special_zero", false}});
+        // Apply MatMul operation for formatted inputs.
+        auto matmul = makeOP({Reshape_1, Reshape_2}, {{"transpose_a", false}, {"transpose_b", true}});
+        // Optionally reshape back by unrolling dimensions corresponding to separate labels if needed.
+        // Target subshapes are calculated by the broadcast_merge_shapes function and concatenated.
+        auto reshape_outshape = makeOP({separate1_subshape, separate2_subshape}, {{"axis", 0}});
+        auto reshape_out = makeOP({matmul, reshape_outshape}, {{"special_zero", false}});
+        // Transpose to the original order of output labels.
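+        // (matmul contracted the flattened reduced axes and reshape_out unrolled the separate
+        // dimensions back; the Transpose below only reorders them.)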
        auto Constant_1363 = makeConst(element::i64,
                                        ov::Shape({
                                            3,
                                        }),
                                        {1, 0, 2});
-        auto node_4 = makeOP({Reshape_1362, Constant_1363});
-        model_ref = std::make_shared(NodeVector{node_4}, ParameterVector{node_2, node_0});
+        auto transpose_out = makeOP({reshape_out, Constant_1363});
+        model_ref = std::make_shared(NodeVector{transpose_out}, ParameterVector{data_1, data_2});
     }
 }
 

From f4cb1b3e01c61b2c03501069b9d2ec0d988b1822 Mon Sep 17 00:00:00 2001
From: Mateusz Mikolajczyk 
Date: Wed, 12 Feb 2025 13:42:06 +0100
Subject: [PATCH 30/47] Improve readability of einsum decomposition test 4

Signed-off-by: Mateusz Mikolajczyk 
---
 .../op_conversions/einsum_decomposition_test.cpp   | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp b/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp
index 5e0f1acf1cf74a..2e47bb83b6ce9b 100644
--- a/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp
+++ b/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp
@@ -457,8 +457,9 @@ TEST_F(TransformationTestsF, Einsum_1in_repeated_labels_ellipsis_static_cf) {
     }
     {
         using namespace ov::gen_pattern;
-        auto node_0 = std::make_shared(element::f32, data_shape_1);
-        auto Multiply_1382 = makeConst(
+        auto data_1 = std::make_shared(element::f32, data_shape_1);
+        // If shapes are static, multi-identity can be constant-folded.
+        auto multi_identity = makeConst(
             element::f32,
             ov::Shape({
                 1,
                 3,
                 1,
                 1,
                 3,
                 1,
             }),
             {1.000000f, 0.000000f, 0.000000f, 0.000000f, 1.000000f, 0.000000f, 0.000000f, 0.000000f, 1.000000f});
-        auto Multiply_1383 = makeOP({node_0, Multiply_1382}, {{"auto_broadcast", "numpy"}});
+        auto Multiply_1383 = makeOP({data_1, multi_identity}, {{"auto_broadcast", "numpy"}});
         auto Constant_1384 = makeConst(element::i64,
                                        ov::Shape({
                                            3,
                                        }),
                                        {3, 4, 5});
-        auto ReduceSum_1385 = makeOP({Multiply_1383, Constant_1384}, {{"keep_dims", false}});
+        auto data_1_diagonal = makeOP({Multiply_1383, Constant_1384}, {{"keep_dims", false}});
+        // Transpose to the original order of output labels.
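+        // (Axes {3, 4, 5} summed above are the duplicate occurrences of the repeated labels,
+        // leaving one dimension per distinct label.)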
        auto Constant_1386 = makeConst(element::i64,
                                        ov::Shape({
                                            3,
                                        }),
                                        {1, 2, 0});
-        auto node_2 = makeOP({ReduceSum_1385, Constant_1386});
-        model_ref = std::make_shared(NodeVector{node_2}, ParameterVector{node_0});
+        auto transpose_out = makeOP({data_1_diagonal, Constant_1386});
+        model_ref = std::make_shared(NodeVector{transpose_out}, ParameterVector{data_1});
     }
 }
 

From 509e30255bf20f787d956a1ed8f8539ba5bca689 Mon Sep 17 00:00:00 2001
From: Mateusz Mikolajczyk 
Date: Wed, 12 Feb 2025 15:03:43 +0100
Subject: [PATCH 31/47] Improve readability for einsum decomposition test with
 duplicated label

Signed-off-by: Mateusz Mikolajczyk 
---
 .../einsum_decomposition_test.cpp             | 140 ++++++++----------
 1 file changed, 59 insertions(+), 81 deletions(-)

diff --git a/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp b/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp
index 2e47bb83b6ce9b..8e54e11dbbb4ba 100644
--- a/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp
+++ b/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp
@@ -488,6 +488,47 @@ TEST_F(TransformationTestsF, Einsum_1in_repeated_labels_ellipsis_static_cf) {
     }
 }
 
+namespace {
+using namespace ov::gen_pattern;
+std::shared_ptr create_identity(const std::shared_ptr& data,
+                                          const std::vector& repated_label_indices) {
+    auto shapeof_data = makeOP({data}, {{"output_type", "i64"}});
+    auto rankof_data = makeOP({shapeof_data});
+    auto const_0 = makeConst(element::i64, ov::Shape({}), {0});
+    auto const_1 = makeConst(element::i64, ov::Shape({}), {1});
+    auto num_of_repeated_labels = makeConst(element::i64, ov::Shape({}), {repated_label_indices.size()});
+    auto repeated_label_indices = makeConst(element::i64,
+                                            ov::Shape({
+                                                repated_label_indices.size(),
+                                            }),
+                                            repated_label_indices);
+    auto repeated_dimensions =
+        makeOP({shapeof_data, repeated_label_indices, const_0}, {{"batch_dims", 0}});
+    auto repeated_dimensions_size = makeOP({repeated_dimensions, const_0}, {{"keep_dims", true}});
+    auto zeros_of_size = makeOP({const_0, repeated_dimensions_size}, {{"mode", "numpy"}});
+    auto repeated_dimension = makeOP({repeated_dimensions, const_0, const_0}, {{"batch_dims", 0}});
+    auto range_max_val =
+        makeOP({repeated_dimension, num_of_repeated_labels}, {{"auto_broadcast", "numpy"}});
+    auto step_numerator = makeOP({range_max_val, const_1}, {{"auto_broadcast", "numpy"}});
+    auto step_numerator_but_not_0 = makeOP({step_numerator, const_1}, {{"auto_broadcast", "numpy"}});
+    auto step_denominator = makeOP({repeated_dimension, const_1}, {{"auto_broadcast", "numpy"}});
+    auto step_denominator_but_not_0 =
+        makeOP({step_denominator, const_1}, {{"auto_broadcast", "numpy"}});
+    auto step = makeOP({step_numerator_but_not_0, step_denominator_but_not_0},
+                                       {{"auto_broadcast", "numpy"}, {"m_pythondiv", true}});
+    auto eye_flattened_indices = makeOP({const_0, range_max_val, step});
+    auto repeated_dimension_1d = makeOP({repeated_dimension, const_0});
+    auto ones = makeOP({const_1, repeated_dimension_1d}, {{"mode", "numpy"}});
+    auto eye_flattened = makeOP({zeros_of_size, eye_flattened_indices, ones, const_0});
+    auto ones_of_input_shape_rank = makeOP({const_1, rankof_data}, {{"mode", "numpy"}});
+    auto identity_shape = makeOP(
+        {ones_of_input_shape_rank, repeated_label_indices, repeated_dimensions, const_0});
+    auto identity = makeOP({eye_flattened, identity_shape}, {{"special_zero", false}});
+    return identity;
+}
+
+} // namespace
+
 TEST_F(TransformationTestsF,
Einsum_1in_repeated_labels_empty_ellipsis_dynamic) { PartialShape data_shape_1 = PartialShape::dynamic(5); { @@ -498,97 +539,34 @@ TEST_F(TransformationTestsF, Einsum_1in_repeated_labels_empty_ellipsis_dynamic) } { using namespace ov::gen_pattern; - auto node_0 = std::make_shared(element::f32, data_shape_1); - auto Constant_2112 = makeConst(element::i64, ov::Shape({}), {0}); - auto ShapeOf_2109 = makeOP({node_0}, {{"output_type", "i64"}}); - auto Constant_2110 = makeConst(element::i64, - ov::Shape({ - 3, - }), - {0, 2, 4}); - auto Gather_2114 = makeOP({ShapeOf_2109, Constant_2110, Constant_2112}, {{"batch_dims", 0}}); - auto ReduceProd_2526 = makeOP({Gather_2114, Constant_2112}, {{"keep_dims", true}}); - auto Constant_2527 = makeConst(element::u8, ov::Shape({}), {0}); - auto Broadcast_2528 = - makeOP({Constant_2112, ReduceProd_2526, Constant_2527}, {{"mode", "numpy"}}); - auto Gather_2115 = makeOP({Gather_2114, Constant_2112, Constant_2112}, {{"batch_dims", 0}}); - auto Constant_2111 = makeConst(element::i64, ov::Shape({}), {3}); - auto Power_2116 = makeOP({Gather_2115, Constant_2111}, {{"auto_broadcast", "numpy"}}); - auto Constant_2113 = makeConst(element::i64, ov::Shape({}), {1}); - auto Subtract_2117 = makeOP({Power_2116, Constant_2113}, {{"auto_broadcast", "numpy"}}); - auto Maximum_2120 = makeOP({Subtract_2117, Constant_2113}, {{"auto_broadcast", "numpy"}}); - auto Subtract_2118 = makeOP({Gather_2115, Constant_2113}, {{"auto_broadcast", "numpy"}}); - auto Maximum_2119 = makeOP({Subtract_2118, Constant_2113}, {{"auto_broadcast", "numpy"}}); - auto Divide_2121 = - makeOP({Maximum_2120, Maximum_2119}, {{"auto_broadcast", "numpy"}, {"m_pythondiv", true}}); - auto Range_2122 = makeOP({Constant_2112, Power_2116, Divide_2121}); - auto Unsqueeze_2521 = makeOP({Gather_2115, Constant_2112}); - auto Constant_2522 = makeConst(element::u8, ov::Shape({}), {0}); - auto Broadcast_2523 = - makeOP({Constant_2113, Unsqueeze_2521, Constant_2522}, {{"mode", "numpy"}}); - auto ScatterElementsUpdate_2557 = - makeOP({Broadcast_2528, Range_2122, Broadcast_2523, Constant_2112}); - auto ShapeOf_2558 = makeOP({ShapeOf_2109}); - auto Constant_2559 = makeConst(element::u8, ov::Shape({}), {0}); - auto Broadcast_2560 = - makeOP({Constant_2113, ShapeOf_2558, Constant_2559}, {{"mode", "numpy"}}); - auto ScatterElementsUpdate_2563 = - makeOP({Broadcast_2560, Constant_2110, Gather_2114, Constant_2112}); - auto Reshape_2564 = makeOP({ScatterElementsUpdate_2557, ScatterElementsUpdate_2563}, - {{"special_zero", false}}); - auto Constant_2569 = makeConst(element::i64, ov::Shape({}), {0}); - auto ShapeOf_2566 = makeOP({node_0}, {{"output_type", "i64"}}); - auto Constant_2567 = makeConst(element::i64, - ov::Shape({ - 2, - }), - {1, 3}); - auto Gather_2571 = makeOP({ShapeOf_2566, Constant_2567, Constant_2569}, {{"batch_dims", 0}}); - auto ReduceProd_2983 = makeOP({Gather_2571, Constant_2569}, {{"keep_dims", true}}); - auto Constant_2984 = makeConst(element::u8, ov::Shape({}), {0}); - auto Broadcast_2985 = - makeOP({Constant_2569, ReduceProd_2983, Constant_2984}, {{"mode", "numpy"}}); - auto Gather_2572 = makeOP({Gather_2571, Constant_2569, Constant_2569}, {{"batch_dims", 0}}); - auto Constant_2568 = makeConst(element::i64, ov::Shape({}), {2}); - auto Power_2573 = makeOP({Gather_2572, Constant_2568}, {{"auto_broadcast", "numpy"}}); - auto Constant_2570 = makeConst(element::i64, ov::Shape({}), {1}); - auto Subtract_2574 = makeOP({Power_2573, Constant_2570}, {{"auto_broadcast", "numpy"}}); - auto Maximum_2577 = 
makeOP({Subtract_2574, Constant_2570}, {{"auto_broadcast", "numpy"}});
-        auto Subtract_2575 = makeOP({Gather_2572, Constant_2570}, {{"auto_broadcast", "numpy"}});
-        auto Maximum_2576 = makeOP({Subtract_2575, Constant_2570}, {{"auto_broadcast", "numpy"}});
-        auto Divide_2578 =
-            makeOP({Maximum_2577, Maximum_2576}, {{"auto_broadcast", "numpy"}, {"m_pythondiv", true}});
-        auto Range_2579 = makeOP({Constant_2569, Power_2573, Divide_2578});
-        auto Unsqueeze_2978 = makeOP({Gather_2572, Constant_2569});
-        auto Constant_2979 = makeConst(element::u8, ov::Shape({}), {0});
-        auto Broadcast_2980 =
-            makeOP({Constant_2570, Unsqueeze_2978, Constant_2979}, {{"mode", "numpy"}});
-        auto ScatterElementsUpdate_3014 =
-            makeOP({Broadcast_2985, Range_2579, Broadcast_2980, Constant_2569});
-        auto ShapeOf_3015 = makeOP({ShapeOf_2566});
-        auto Constant_3016 = makeConst(element::u8, ov::Shape({}), {0});
-        auto Broadcast_3017 =
-            makeOP({Constant_2570, ShapeOf_3015, Constant_3016}, {{"mode", "numpy"}});
-        auto ScatterElementsUpdate_3020 =
-            makeOP({Broadcast_3017, Constant_2567, Gather_2571, Constant_2569});
-        auto Reshape_3021 = makeOP({ScatterElementsUpdate_3014, ScatterElementsUpdate_3020},
-                                            {{"special_zero", false}});
-        auto Multiply_3023 = makeOP({Reshape_2564, Reshape_3021}, {{"auto_broadcast", "numpy"}});
-        auto ConvertLike_3024 = makeOP({Multiply_3023, node_0});
-        auto Multiply_3024 = makeOP({node_0, ConvertLike_3024}, {{"auto_broadcast", "numpy"}});
+        auto data_1 = std::make_shared(element::f32, data_shape_1);
+
+        // Create identity for repeated label i
+        auto identity_i = create_identity(data_1, {0, 2, 4});
+        // Create identity for repeated label j
+        auto identity_j = create_identity(data_1, {1, 3});
+
+        // Merge identities for all repeated labels to create multi-identity
+        auto multi_identity = makeOP({identity_i, identity_j}, {{"auto_broadcast", "numpy"}});
+
+        // Extract diagonals by multiplying by multi-identity and reducing
+        auto multi_identity_cvt = makeOP({multi_identity, data_1});
+        auto Multiply_3024 = makeOP({data_1, multi_identity_cvt}, {{"auto_broadcast", "numpy"}});
         auto Constant_3025 = makeConst(element::i64,
                                        ov::Shape({
                                            3,
                                        }),
                                        {2, 3, 4});
-        auto ReduceSum_3026 = makeOP({Multiply_3024, Constant_3025}, {{"keep_dims", false}});
+        auto data_1_diagonal = makeOP({Multiply_3024, Constant_3025}, {{"keep_dims", false}});
+
+        // Transpose to the original order of output labels.
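+        // (Only one dimension per distinct repeated label survives the diagonal extraction,
+        // so the reference model ends with a rank-2 transpose.)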
        auto Constant_3027 = makeConst(element::i64,
                                        ov::Shape({
                                            2,
                                        }),
                                        {1, 0});
-        auto node_2 = makeOP({ReduceSum_3026, Constant_3027});
-        model_ref = std::make_shared(NodeVector{node_2}, ParameterVector{node_0});
+        auto transpose_out = makeOP({data_1_diagonal, Constant_3027});
+        model_ref = std::make_shared(NodeVector{transpose_out}, ParameterVector{data_1});
     }
 }
 

From 09533e428c87ae1fc1f40335fb337f9f82da6d41 Mon Sep 17 00:00:00 2001
From: Mateusz Mikolajczyk 
Date: Wed, 12 Feb 2025 15:57:13 +0100
Subject: [PATCH 32/47] Extract subshape extraction to separate function for
 readability

Signed-off-by: Mateusz Mikolajczyk 
---
 .../einsum_decomposition_test.cpp             | 308 +++++------------
 1 file changed, 84 insertions(+), 224 deletions(-)

diff --git a/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp b/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp
index 8e54e11dbbb4ba..5e84e9da19164e 100644
--- a/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp
+++ b/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp
@@ -12,7 +12,72 @@
 #include "transformations/utils/gen_pattern.hpp"
 
 using namespace ov;
+namespace {
+using namespace ov::gen_pattern;
+std::shared_ptr extract_subshape_from_shape(const std::shared_ptr& shape_node,
+                                                      size_t begin,
+                                                      size_t end) {
+    auto const_begin = makeConst(element::i64,
+                                 ov::Shape({
+                                     1,
+                                 }),
+                                 {begin});
+    auto const_end = makeConst(element::i64,
+                               ov::Shape({
+                                   1,
+                               }),
+                               {end});
+    auto const_1 = makeConst(element::i64,
+                             ov::Shape({
+                                 1,
+                             }),
+                             {1});
+    auto subshape = makeOP({shape_node, const_begin, const_end, const_1},
+                                                {{"begin_mask", {0}},
+                                                 {"end_mask", {0}},
+                                                 {"new_axis_mask", {}},
+                                                 {"shrink_axis_mask", {}},
+                                                 {"ellipsis_mask", {}}});
+    return subshape;
+}
+std::shared_ptr create_identity(const std::shared_ptr& data,
+                                          const std::vector& repated_label_indices) {
+    auto shapeof_data = makeOP({data}, {{"output_type", "i64"}});
+    auto rankof_data = makeOP({shapeof_data});
+    auto const_0 = makeConst(element::i64, ov::Shape({}), {0});
+    auto const_1 = makeConst(element::i64, ov::Shape({}), {1});
+    auto num_of_repeated_labels = makeConst(element::i64, ov::Shape({}), {repated_label_indices.size()});
+    auto repeated_label_indices = makeConst(element::i64,
+                                            ov::Shape({
+                                                repated_label_indices.size(),
+                                            }),
+                                            repated_label_indices);
+    auto repeated_dimensions =
+        makeOP({shapeof_data, repeated_label_indices, const_0}, {{"batch_dims", 0}});
+    auto repeated_dimensions_size = makeOP({repeated_dimensions, const_0}, {{"keep_dims", true}});
+    auto zeros_of_size = makeOP({const_0, repeated_dimensions_size}, {{"mode", "numpy"}});
+    auto repeated_dimension = makeOP({repeated_dimensions, const_0, const_0}, {{"batch_dims", 0}});
+    auto range_max_val =
+        makeOP({repeated_dimension, num_of_repeated_labels}, {{"auto_broadcast", "numpy"}});
+    auto step_numerator = makeOP({range_max_val, const_1}, {{"auto_broadcast", "numpy"}});
+    auto step_numerator_but_not_0 = makeOP({step_numerator, const_1}, {{"auto_broadcast", "numpy"}});
+    auto step_denominator = makeOP({repeated_dimension, const_1}, {{"auto_broadcast", "numpy"}});
+    auto step_denominator_but_not_0 =
+        makeOP({step_denominator, const_1}, {{"auto_broadcast", "numpy"}});
+    auto step = makeOP({step_numerator_but_not_0, step_denominator_but_not_0},
+                       {{"auto_broadcast", "numpy"}, {"m_pythondiv", true}});
+    auto eye_flattened_indices = makeOP({const_0, range_max_val, step});
+    auto repeated_dimension_1d = makeOP({repeated_dimension,
const_0}); + auto ones = makeOP({const_1, repeated_dimension_1d}, {{"mode", "numpy"}}); + auto eye_flattened = makeOP({zeros_of_size, eye_flattened_indices, ones, const_0}); + auto ones_of_input_shape_rank = makeOP({const_1, rankof_data}, {{"mode", "numpy"}}); + auto identity_shape = makeOP( + {ones_of_input_shape_rank, repeated_label_indices, repeated_dimensions, const_0}); + auto identity = makeOP({eye_flattened, identity_shape}, {{"special_zero", false}}); + return identity; +} +} // namespace TEST_F(TransformationTestsF, Einsum_2in_matmul) { PartialShape data_shape_1{5, 2}; PartialShape data_shape_2{10, 1, 25}; @@ -89,108 +154,27 @@ TEST_F(TransformationTestsF, Einsum_2in_matmul_dynamic) { auto ShapeOf_data_2 = makeOP({Transpose_486}, {{"output_type", "i64"}}); // Get reduced subshape for data_1. - auto Constant_489 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {1}); - auto Constant_490 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {2}); - auto Constant_492 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {1}); - auto reduced1 = makeOP({ShapeOf_data_1, Constant_489, Constant_490, Constant_492}, - {{"begin_mask", {0}}, - {"end_mask", {0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); + auto reduced1 = extract_subshape_from_shape(ShapeOf_data_1, 1, 2); + // Get reduced subshape for data_2. - auto Constant_494 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {2}); - auto Constant_495 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {3}); - auto Constant_497 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {1}); - auto reduced_2 = makeOP({ShapeOf_data_2, Constant_494, Constant_495, Constant_497}, - {{"begin_mask", {0}}, - {"end_mask", {0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); + auto reduced2 = extract_subshape_from_shape(ShapeOf_data_2, 2, 3); - // broadcast_merge_shapes(reduced1, reduced_2) + // broadcast_merge_shapes(reduced1, reduced2) auto Constant_499 = makeConst(element::i64, ov::Shape({ 1, }), {1}); auto Broadcast_500 = makeOP({Constant_499, reduced1}, {{"mode", "numpy"}}); - auto Broadcast_503 = makeOP({Broadcast_500, reduced_2}, {{"mode", "bidirectional"}}); + auto Broadcast_503 = makeOP({Broadcast_500, reduced2}, {{"mode", "bidirectional"}}); auto reduced_subshape_broadcast_merge_shapes = makeOP({Broadcast_503}, {{"output_type", "i64"}}); // Extract separate subshape for data_1. - auto Constant_507 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {0}); - auto Constant_508 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {1}); - auto Constant_510 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {1}); - auto separate1_subshape = - makeOP({ShapeOf_data_1, Constant_507, Constant_508, Constant_510}, - {{"begin_mask", {0}}, - {"end_mask", {0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); + auto separate1_subshape = extract_subshape_from_shape(ShapeOf_data_1, 0, 1); // Extract separate subshape for data_2. 
-        auto Constant_516 = makeConst(element::i64,
-                                      ov::Shape({
-                                          1,
-                                      }),
-                                      {0});
-        auto Constant_517 = makeConst(element::i64,
-                                      ov::Shape({
-                                          1,
-                                      }),
-                                      {2});
-        auto Constant_519 = makeConst(element::i64,
-                                      ov::Shape({
-                                          1,
-                                      }),
-                                      {1});
-        auto separate2_subshape =
-            makeOP({ShapeOf_data_2, Constant_516, Constant_517, Constant_519},
-                   {{"begin_mask", {0}},
-                    {"end_mask", {0}},
-                    {"new_axis_mask", {}},
-                    {"shrink_axis_mask", {}},
-                    {"ellipsis_mask", {}}});
+        auto separate2_subshape = extract_subshape_from_shape(ShapeOf_data_2, 0, 2);
 
         // Broadcast data_1 and data_2 based on calculated subshapes.
         auto Concat_512 =
@@ -211,10 +195,10 @@ TEST_F(TransformationTestsF, Einsum_2in_matmul_dynamic) {
         // Reduce separate and reduced
         auto Separate1_subshape_red =
             makeOP({separate1_subshape, Constant_525}, {{"keep_dims", true}});
-        auto Reduced1_subshape_red =
+        auto reduced1_subshape_red =
             makeOP({reduced_subshape_broadcast_merge_shapes, {0}}, {{"keep_dims", true}});
         // Merge subshapes
-        auto reshape_subshape1 = makeOP({Separate1_subshape_red, Reduced1_subshape_red}, {{"axis", 0}});
+        auto reshape_subshape1 = makeOP({Separate1_subshape_red, reduced1_subshape_red}, {{"axis", 0}});
         auto Reshape_1 = makeOP({Broadcast_data_1, reshape_subshape1}, {{"special_zero", false}});
         // Reshape 2
@@ -289,110 +273,27 @@ TEST_F(TransformationTestsF, Einsum_2in_matmul_ellipsis_dynamic) {
         auto ShapeOf_data_2 = makeOP({data_2_processed}, {{"output_type", "i64"}});
 
         // Get reduced subshape for data_1.
-        auto Constant_1208 = makeConst(element::i64,
-                                       ov::Shape({
-                                           1,
-                                       }),
-                                       {1});
-        auto Constant_1209 = makeConst(element::i64,
-                                       ov::Shape({
-                                           1,
-                                       }),
-                                       {4});
-        auto Constant_1211 = makeConst(element::i64,
-                                       ov::Shape({
-                                           1,
-                                       }),
-                                       {1});
-        auto StridedSlice_1212 =
-            makeOP({ShapeOf_data_1, Constant_1208, Constant_1209, Constant_1211},
-                   {{"begin_mask", {0}},
-                    {"end_mask", {0}},
-                    {"new_axis_mask", {}},
-                    {"shrink_axis_mask", {}},
-                    {"ellipsis_mask", {}}});
+        auto reduced1 = extract_subshape_from_shape(ShapeOf_data_1, 1, 4);
+
         // Get reduced subshape for data_2.
-        auto Constant_1213 = makeConst(element::i64,
-                                       ov::Shape({
-                                           1,
-                                       }),
-                                       {2});
-        auto Constant_1214 = makeConst(element::i64,
-                                       ov::Shape({
-                                           1,
-                                       }),
-                                       {5});
-        auto Constant_1216 = makeConst(element::i64,
-                                       ov::Shape({
-                                           1,
-                                       }),
-                                       {1});
-        auto StridedSlice_1217 =
-            makeOP({ShapeOf_data_2, Constant_1213, Constant_1214, Constant_1216},
-                   {{"begin_mask", {0}},
-                    {"end_mask", {0}},
-                    {"new_axis_mask", {}},
-                    {"shrink_axis_mask", {}},
-                    {"ellipsis_mask", {}}});
+        auto reduced2 = extract_subshape_from_shape(ShapeOf_data_2, 2, 5);
+
         // broadcast_merge_shapes(reduced1, reduced_2)
         auto Constant_1218 = makeConst(element::i64,
                                        ov::Shape({
                                            1,
                                        }),
                                        {1});
-        auto Broadcast_1219 = makeOP({Constant_1218, StridedSlice_1212}, {{"mode", "numpy"}});
-        auto Broadcast_1222 =
-            makeOP({Broadcast_1219, StridedSlice_1217}, {{"mode", "bidirectional"}});
+        auto Broadcast_1219 = makeOP({Constant_1218, reduced1}, {{"mode", "numpy"}});
+        auto Broadcast_1222 = makeOP({Broadcast_1219, reduced2}, {{"mode", "bidirectional"}});
         auto reduced_subshape_broadcast_merge_shapes =
             makeOP({Broadcast_1222}, {{"output_type", "i64"}});
 
         // Extract separate subshape for data_1.
-        auto Constant_1226 = makeConst(element::i64,
-                                       ov::Shape({
-                                           1,
-                                       }),
-                                       {0});
-        auto Constant_1227 = makeConst(element::i64,
-                                       ov::Shape({
-                                           1,
-                                       }),
-                                       {1});
-        auto Constant_1229 = makeConst(element::i64,
-                                       ov::Shape({
-                                           1,
-                                       }),
-                                       {1});
-        auto separate1_subshape =
-            makeOP({ShapeOf_data_1, Constant_1226, Constant_1227, Constant_1229},
-                   {{"begin_mask", {0}},
-                    {"end_mask", {0}},
-                    {"new_axis_mask", {}},
-                    {"shrink_axis_mask", {}},
-                    {"ellipsis_mask", {}}});
+        auto separate1_subshape = extract_subshape_from_shape(ShapeOf_data_1, 0, 1);
 
         // Extract separate subshape for data_2.
-        auto Constant_1235 = makeConst(element::i64,
-                                       ov::Shape({
-                                           1,
-                                       }),
-                                       {0});
-        auto Constant_1236 = makeConst(element::i64,
-                                       ov::Shape({
-                                           1,
-                                       }),
-                                       {2});
-        auto Constant_1238 = makeConst(element::i64,
-                                       ov::Shape({
-                                           1,
-                                       }),
-                                       {1});
-        auto separate2_subshape =
-            makeOP({ShapeOf_data_2, Constant_1235, Constant_1236, Constant_1238},
-                   {{"begin_mask", {0}},
-                    {"end_mask", {0}},
-                    {"new_axis_mask", {}},
-                    {"shrink_axis_mask", {}},
-                    {"ellipsis_mask", {}}});
+        auto separate2_subshape = extract_subshape_from_shape(ShapeOf_data_2, 0, 2);
 
         // Broadcast data_1 and data_2 based on calculated subshapes.
         auto Concat_1231 =
@@ -412,9 +313,9 @@ TEST_F(TransformationTestsF, Einsum_2in_matmul_ellipsis_dynamic) {
                                        {0});
         auto Separate1_subshape_red =
             makeOP({separate1_subshape, Constant_1244}, {{"keep_dims", true}});
-        auto Reduced1_subshape_red =
+        auto reduced1_subshape_red =
             makeOP({reduced_subshape_broadcast_merge_shapes, {0}}, {{"keep_dims", true}});
-        auto reshape1_shape = makeOP({Separate1_subshape_red, Reduced1_subshape_red}, {{"axis", 0}});
+        auto reshape1_shape = makeOP({Separate1_subshape_red, reduced1_subshape_red}, {{"axis", 0}});
         auto Reshape_1 = makeOP({Broadcast_data_1, reshape1_shape}, {{"special_zero", false}});
 
         // Reshape 2
@@ -488,133 +389,51 @@ TEST_F(TransformationTestsF, Einsum_1in_repeated_labels_ellipsis_static_cf) {
     }
 }
 
-namespace {
-using namespace ov::gen_pattern;
-std::shared_ptr create_identity(const std::shared_ptr& data,
-                                          const std::vector& repated_label_indices) {
-    auto shapeof_data = makeOP({data}, {{"output_type", "i64"}});
-    auto rankof_data = makeOP({shapeof_data});
-    auto const_0 = makeConst(element::i64, ov::Shape({}), {0});
-    auto const_1 = makeConst(element::i64, ov::Shape({}), {1});
-    auto num_of_repeated_labels = makeConst(element::i64, ov::Shape({}), {repated_label_indices.size()});
-    auto repeated_label_indices = makeConst(element::i64,
-                                            ov::Shape({
-                                                repated_label_indices.size(),
-                                            }),
-                                            repated_label_indices);
-    auto repeated_dimensions =
-        makeOP({shapeof_data, repeated_label_indices, const_0}, {{"batch_dims", 0}});
-    auto repeated_dimensions_size = makeOP({repeated_dimensions, const_0}, {{"keep_dims", true}});
-    auto zeros_of_size = makeOP({const_0, repeated_dimensions_size}, {{"mode", "numpy"}});
-    auto repeated_dimension = makeOP({repeated_dimensions, const_0, const_0}, {{"batch_dims", 0}});
-    auto range_max_val =
-        makeOP({repeated_dimension, num_of_repeated_labels}, {{"auto_broadcast", "numpy"}});
-    auto step_numerator = makeOP({range_max_val, const_1}, {{"auto_broadcast", "numpy"}});
-    auto step_numerator_but_not_0 = makeOP({step_numerator, const_1}, {{"auto_broadcast", "numpy"}});
-    auto step_denominator = makeOP({repeated_dimension, const_1}, {{"auto_broadcast", "numpy"}});
-    auto step_denominator_but_not_0 =
-        makeOP({step_denominator, const_1}, {{"auto_broadcast", "numpy"}});
-    auto step = makeOP({step_numerator_but_not_0, step_denominator_but_not_0},
-
{{"auto_broadcast", "numpy"}, {"m_pythondiv", true}}); - auto eye_flattened_indices = makeOP({const_0, range_max_val, step}); - auto repeated_dimension_1d = makeOP({repeated_dimension, const_0}); - auto ones = makeOP({const_1, repeated_dimension_1d}, {{"mode", "numpy"}}); - auto eye_flattened = makeOP({zeros_of_size, eye_flattened_indices, ones, const_0}); - auto ones_of_input_shape_rank = makeOP({const_1, rankof_data}, {{"mode", "numpy"}}); - auto identity_shape = makeOP( - {ones_of_input_shape_rank, repeated_label_indices, repeated_dimensions, const_0}); - auto identity = makeOP({eye_flattened, identity_shape}, {{"special_zero", false}}); - return identity; -} - -} // namespace - TEST_F(TransformationTestsF, Einsum_1in_repeated_labels_empty_ellipsis_dynamic) { PartialShape data_shape_1 = PartialShape::dynamic(5); { From 83e0317fb3886d908e84790859b406fd47346803 Mon Sep 17 00:00:00 2001 From: Mateusz Mikolajczyk Date: Wed, 12 Feb 2025 16:13:24 +0100 Subject: [PATCH 33/47] Extract broadcast_merge_shapes to separate function Signed-off-by: Mateusz Mikolajczyk --- .../einsum_decomposition_test.cpp | 49 ++++++------------- 1 file changed, 16 insertions(+), 33 deletions(-) diff --git a/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp b/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp index 5e84e9da19164e..44bd799526d71a 100644 --- a/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp +++ b/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp @@ -17,21 +17,9 @@ using namespace ov::gen_pattern; std::shared_ptr extract_subshape_from_shape(const std::shared_ptr& shape_node, size_t begin, size_t end) { - auto const_begin = makeConst(element::i64, - ov::Shape({ - 1, - }), - {begin}); - auto const_end = makeConst(element::i64, - ov::Shape({ - 1, - }), - {end}); - auto const_1 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {1}); + auto const_begin = makeConst(element::i64, ov::Shape({1}), {begin}); + auto const_end = makeConst(element::i64, ov::Shape({1}), {end}); + auto const_1 = makeConst(element::i64, ov::Shape({1}), {1}); auto subshape = makeOP({shape_node, const_begin, const_end, const_1}, {{"begin_mask", {0}}, {"end_mask", {0}}, @@ -40,6 +28,17 @@ std::shared_ptr extract_subshape_from_shape(const std::shared_ptr broadcast_merge_shapes(const std::shared_ptr& shape_node_lhs, + const std::shared_ptr& shape_node_rhs) { + auto const_1 = makeConst(element::i64, ov::Shape({1}), {1}); + auto tensor_of_lhs_shape = makeOP({const_1, shape_node_lhs}, {{"mode", "numpy"}}); + auto tensor_of_broadcasted_lhs_rhs_shape = + makeOP({tensor_of_lhs_shape, shape_node_rhs}, {{"mode", "bidirectional"}}); + auto broadcasted_shapes = makeOP({tensor_of_broadcasted_lhs_rhs_shape}, {{"output_type", "i64"}}); + return broadcasted_shapes; +} + std::shared_ptr create_identity(const std::shared_ptr& data, const std::vector& repated_label_indices) { auto shapeof_data = makeOP({data}, {{"output_type", "i64"}}); @@ -160,15 +159,7 @@ TEST_F(TransformationTestsF, Einsum_2in_matmul_dynamic) { auto reduced2 = extract_subshape_from_shape(ShapeOf_data_2, 2, 3); // broadcast_merge_shapes(reduced1, reduced2) - auto Constant_499 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {1}); - auto Broadcast_500 = makeOP({Constant_499, reduced1}, {{"mode", "numpy"}}); - auto Broadcast_503 = makeOP({Broadcast_500, reduced2}, {{"mode", "bidirectional"}}); - auto reduced_subshape_broadcast_merge_shapes = - makeOP({Broadcast_503}, 
{{"output_type", "i64"}}); + auto reduced_subshape_broadcast_merge_shapes = broadcast_merge_shapes(reduced1, reduced2); // Extract separate subshape for data_1. auto separate1_subshape = extract_subshape_from_shape(ShapeOf_data_1, 0, 1); @@ -279,15 +270,7 @@ TEST_F(TransformationTestsF, Einsum_2in_matmul_ellipsis_dynamic) { auto reduced2 = extract_subshape_from_shape(ShapeOf_data_2, 2, 5); // broadcast_merge_shapes(reduced1, reduced_2) - auto Constant_1218 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {1}); - auto Broadcast_1219 = makeOP({Constant_1218, reduced1}, {{"mode", "numpy"}}); - auto Broadcast_1222 = makeOP({Broadcast_1219, reduced2}, {{"mode", "bidirectional"}}); - auto reduced_subshape_broadcast_merge_shapes = - makeOP({Broadcast_1222}, {{"output_type", "i64"}}); + auto reduced_subshape_broadcast_merge_shapes = broadcast_merge_shapes(reduced1, reduced2); // Extract separate subshape for data_1. auto separate1_subshape = extract_subshape_from_shape(ShapeOf_data_1, 0, 1); From a5e46dda9d0c625b577febc63c89f5ec562e7142 Mon Sep 17 00:00:00 2001 From: Mateusz Mikolajczyk Date: Wed, 12 Feb 2025 17:28:47 +0100 Subject: [PATCH 34/47] Add helper for diagonal extraction for redability Signed-off-by: Mateusz Mikolajczyk --- .../einsum_decomposition_test.cpp | 45 ++++++++++++------- 1 file changed, 28 insertions(+), 17 deletions(-) diff --git a/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp b/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp index 44bd799526d71a..7494961d2a5298 100644 --- a/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp +++ b/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp @@ -76,6 +76,28 @@ std::shared_ptr create_identity(const std::shared_ptr& data, return identity; } +std::shared_ptr extract_diagonal(const std::shared_ptr& data, + const std::vector>& indices_of_repeated_labels) { + // Initialize multi_identity by identity for first repeated label. + auto multi_identity = create_identity(data, indices_of_repeated_labels[0]); + // Initialize reduction axes by all except first repated_label_indices for first repeated label. + std::vector reduce_axes(indices_of_repeated_labels[0].begin() + 1, indices_of_repeated_labels[0].end()); + // Merge remaining identities. 
+    for (size_t i = 1; i < indices_of_repeated_labels.size(); i++) {
+        auto identity = create_identity(data, indices_of_repeated_labels[i]);
+        multi_identity = makeOP({multi_identity, identity}, {{"auto_broadcast", "numpy"}});
+        reduce_axes.insert(reduce_axes.end(),
+                           indices_of_repeated_labels[i].begin() + 1,
+                           indices_of_repeated_labels[i].end());
+    }
+    // Convert to match type of data
+    auto multi_identity_cvt = makeOP({multi_identity, data});
+    auto unreduced_diagonal = makeOP({data, multi_identity_cvt}, {{"auto_broadcast", "numpy"}});
+    auto const_reduce_axes = makeConst(element::i64, ov::Shape({reduce_axes.size()}), reduce_axes);
+    auto diagonal = makeOP({unreduced_diagonal, const_reduce_axes}, {{"keep_dims", false}});
+    return diagonal;
+}
+
 } // namespace
 TEST_F(TransformationTestsF, Einsum_2in_matmul) {
     PartialShape data_shape_1{5, 2};
@@ -384,23 +406,12 @@ TEST_F(TransformationTestsF, Einsum_1in_repeated_labels_empty_ellipsis_dynamic)
         using namespace ov::gen_pattern;
         auto data_1 = std::make_shared(element::f32, data_shape_1);
 
-        // Create identity for repeated label i
-        auto identity_i = create_identity(data_1, {0, 2, 4});
-        // Create identity for repeated label j
-        auto identity_j = create_identity(data_1, {1, 3});
-
-        // Merge identities for all repeated labels to create multi-identity
-        auto multi_identity = makeOP({identity_i, identity_j}, {{"auto_broadcast", "numpy"}});
-
-        // Extract diagonals by multiplying by multi-identity and reducing
-        auto multi_identity_cvt = makeOP({multi_identity, data_1});
-        auto Multiply_3024 = makeOP({data_1, multi_identity_cvt}, {{"auto_broadcast", "numpy"}});
-        auto Constant_3025 = makeConst(element::i64,
-                                       ov::Shape({
-                                           3,
-                                       }),
-                                       {2, 3, 4});
-        auto data_1_diagonal = makeOP({Multiply_3024, Constant_3025}, {{"keep_dims", false}});
+        // Extract diagonal
+        auto data_1_diagonal = extract_diagonal(data_1,
+                                                {
+                                                    {0, 2, 4},  // indices of repeated label i
+                                                    {1, 3},     // indices of repeated label j
+                                                });
 
         // Transpose to the original order of output labels.
 

From 6c89a9e6cc06ad5c543a21bebfa800f93aaec214 Mon Sep 17 00:00:00 2001
From: Mateusz Mikolajczyk 
Date: Wed, 12 Feb 2025 17:40:18 +0100
Subject: [PATCH 35/47] Fix const formatting

Signed-off-by: Mateusz Mikolajczyk 
---
 .../einsum_decomposition_test.cpp             | 207 +++---------------
 1 file changed, 35 insertions(+), 172 deletions(-)

diff --git a/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp b/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp
index 7494961d2a5298..6bcb8ff5eb35ef 100644
--- a/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp
+++ b/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp
@@ -164,11 +164,7 @@ TEST_F(TransformationTestsF, Einsum_2in_matmul_dynamic) {
         auto data_1 = std::make_shared(element::f32, data_shape_1);
         auto data_2 = std::make_shared(element::f32, data_shape_2);
         // Transpose data_2 so that common labels, separated and reduced labels are grouped for both operands.
-        auto Constant_485 = makeConst(element::i64,
-                                      ov::Shape({
-                                          3,
-                                      }),
-                                      {0, 2, 1});
+        auto Constant_485 = makeConst(element::i64, ov::Shape({3}), {0, 2, 1});
         auto Transpose_486 = makeOP({data_2, Constant_485});
 
         // Get shapes of data_1 and data_2.
        auto ShapeOf_data_1 = makeOP({data_1}, {{"output_type", "i64"}});
@@ -200,11 +196,7 @@ TEST_F(TransformationTestsF, Einsum_2in_matmul_dynamic) {
         // Optionally reshape broadcasted data_1 and data_2 so separate and reduced labels are represented by one
         // dimension. Subgraphs are constant-folded; target subshapes are calculated by the broadcast_merge_shapes function.
         // Reshape 1
-        auto Constant_525 = makeConst(element::i64,
-                                      ov::Shape({
-                                          1,
-                                      }),
-                                      {0});
+        auto Constant_525 = makeConst(element::i64, ov::Shape({1}), {0});
         // Reduce separate and reduced
         auto Separate1_subshape_red =
             makeOP({separate1_subshape, Constant_525}, {{"keep_dims", true}});
@@ -214,11 +206,7 @@ TEST_F(TransformationTestsF, Einsum_2in_matmul_dynamic) {
         auto reshape_subshape1 = makeOP({Separate1_subshape_red, reduced1_subshape_red}, {{"axis", 0}});
         auto Reshape_1 = makeOP({Broadcast_data_1, reshape_subshape1}, {{"special_zero", false}});
         // Reshape 2
-        auto Constant_569 = makeConst(element::i64,
-                                      ov::Shape({
-                                          1,
-                                      }),
-                                      {0});
+        auto Constant_569 = makeConst(element::i64, ov::Shape({1}), {0});
         // Reduce separate and reduced
         auto Separate2_subshape_red =
             makeOP({separate2_subshape, Constant_569}, {{"keep_dims", true}});
@@ -259,26 +247,14 @@ TEST_F(TransformationTestsF, Einsum_2in_matmul_ellipsis_dynamic) {
         auto data_2 = std::make_shared(element::f32, data_shape_2);
         // Process data_1
         // data_1 contains no dimensions at ellipsis label, unsqueeze to allow for broadcasting
-        auto Constant_1200 = makeConst(element::i64,
-                                       ov::Shape({
-                                           1,
-                                       }),
-                                       {2});
+        auto Constant_1200 = makeConst(element::i64, ov::Shape({1}), {2});
         auto Unsqueeze_1201 = makeOP({data_1, Constant_1200});
         // Match ranks of dimensions covered by ellipsis labels
-        auto Constant_1202 = makeConst(element::i64,
-                                       ov::Shape({
-                                           1,
-                                       }),
-                                       {2});
+        auto Constant_1202 = makeConst(element::i64, ov::Shape({1}), {2});
         auto data_1_processed = makeOP({Unsqueeze_1201, Constant_1202});
         // Process data_2
         // Transpose data_2 so that common labels, separated and reduced labels are grouped for both operands.
-        auto Constant_1204 = makeConst(element::i64,
-                                       ov::Shape({
-                                           5,
-                                       }),
-                                       {0, 4, 3, 1, 2});
+        auto Constant_1204 = makeConst(element::i64, ov::Shape({5}), {0, 4, 3, 1, 2});
         auto data_2_processed = makeOP({data_2, Constant_1204});
 
         // Get shapes for data_1 and data_2
@@ -311,7 +287,7 @@ TEST_F(TransformationTestsF, Einsum_2in_matmul_ellipsis_dynamic) {
         // Optionally reshape broadcasted data_1 and data_2 so separate and reduced labels are represented by one
         // dimension. Subgraphs are constant-folded; target subshapes are calculated by the broadcast_merge_shapes function.
// Reshape 1 - auto Constant_1244 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {0}); + auto Constant_1244 = makeConst(element::i64, ov::Shape({1}), {0}); auto Separate1_subshape_red = makeOP({separate1_subshape, Constant_1244}, {{"keep_dims", true}}); auto reduced1_subshape_red = @@ -324,11 +296,7 @@ TEST_F(TransformationTestsF, Einsum_2in_matmul_ellipsis_dynamic) { auto Reshape_1 = makeOP({Broadcast_data_1, reshape1_shape}, {{"special_zero", false}}); // Reshape 2 - auto Constant_1302 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {0}); + auto Constant_1302 = makeConst(element::i64, ov::Shape({1}), {0}); auto Separate2_subshape_red = makeOP({separate2_subshape, Constant_1302}, {{"keep_dims", true}}); auto Reduced2_subshape_red = @@ -342,11 +310,7 @@ TEST_F(TransformationTestsF, Einsum_2in_matmul_ellipsis_dynamic) { auto reshape_outshape = makeOP({separate1_subshape, separate2_subshape}, {{"axis", 0}}); auto reshape_out = makeOP({matmul, reshape_outshape}, {{"special_zero", false}}); // Transpose to the original order of output labels. - auto Constant_1363 = makeConst(element::i64, - ov::Shape({ - 3, - }), - {1, 0, 2}); + auto Constant_1363 = makeConst(element::i64, ov::Shape({3}), {1, 0, 2}); auto transpose_out = makeOP({reshape_out, Constant_1363}); model_ref = std::make_shared(NodeVector{transpose_out}, ParameterVector{data_1, data_2}); } @@ -367,28 +331,13 @@ TEST_F(TransformationTestsF, Einsum_1in_repeated_labels_ellipsis_static_cf) { // If shapes are static, multi-identity can be constant-folded. auto multi_identity = makeConst( element::f32, - ov::Shape({ - 1, - 3, - 1, - 1, - 3, - 1, - }), + ov::Shape({1, 3, 1, 1, 3, 1}), {1.000000f, 0.000000f, 0.000000f, 0.000000f, 1.000000f, 0.000000f, 0.000000f, 0.000000f, 1.000000f}); auto Multiply_1383 = makeOP({data_1, multi_identity}, {{"auto_broadcast", "numpy"}}); - auto Constant_1384 = makeConst(element::i64, - ov::Shape({ - 3, - }), - {3, 4, 5}); + auto Constant_1384 = makeConst(element::i64, ov::Shape({3}), {3, 4, 5}); auto data_1_diagonal = makeOP({Multiply_1383, Constant_1384}, {{"keep_dims", false}}); // Transpose to the original order of output labels. - auto Constant_1386 = makeConst(element::i64, - ov::Shape({ - 3, - }), - {1, 2, 0}); + auto Constant_1386 = makeConst(element::i64, ov::Shape({3}), {1, 2, 0}); auto transpose_out = makeOP({data_1_diagonal, Constant_1386}); model_ref = std::make_shared(NodeVector{transpose_out}, ParameterVector{data_1}); } @@ -414,11 +363,7 @@ TEST_F(TransformationTestsF, Einsum_1in_repeated_labels_empty_ellipsis_dynamic) }); // Transpose to the original order of output labels. 
-        auto Constant_3027 = makeConst(element::i64,
-                                       ov::Shape({
-                                           2,
-                                       }),
-                                       {1, 0});
+        auto Constant_3027 = makeConst(element::i64, ov::Shape({2}), {1, 0});
         auto transpose_out = makeOP({data_1_diagonal, Constant_3027});
         model_ref = std::make_shared(NodeVector{transpose_out}, ParameterVector{data_1});
     }
 }
 

From 3cf8b8320e443e626a48070e29b342afbc674123 Mon Sep 17 00:00:00 2001
From: Mateusz Mikolajczyk 
Date: Wed, 12 Feb 2025 19:18:55 +0100
Subject: [PATCH 36/47] Improve readability for einsum decomposition test

Signed-off-by: Mateusz Mikolajczyk 
---
 .../einsum_decomposition_test.cpp             | 421 ++++--------------
 1 file changed, 92 insertions(+), 329 deletions(-)

diff --git a/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp b/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp
index 6bcb8ff5eb35ef..177f2c2ba46f3c 100644
--- a/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp
+++ b/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp
@@ -388,6 +388,47 @@ TEST_F(TransformationTestsF, Einsum_3in_broadcast_duplicated_ellipsis_repeated_s
     }
     {
         using namespace ov::gen_pattern;
         auto node_0 = std::make_shared(element::f32, data_shape_3);
         auto node_2 = std::make_shared(element::f32, data_shape_2);
         auto node_4 = std::make_shared(element::f32, data_shape_1);
-        auto Multiply_1990 = makeConst(element::f32,
-                                       ov::Shape({
-                                           1,
-                                           1,
-                                           1,
-                                           1,
-                                           1,
-                                           1,
-                                       }),
-                                       {1.000000f});
+        // ConstantFold folded the multi-identity for input 2 into a single constant
+        auto Multiply_1990 = makeConst(element::f32, ov::Shape({1, 1, 1, 1, 1, 1}), {1.000000f});
+        // Extract diagonals
         auto Multiply_1991 = makeOP({node_2, Multiply_1990}, {{"auto_broadcast", "numpy"}});
-        auto Constant_1992 = makeConst(element::i64,
-                                       ov::Shape({
-                                           3,
-                                       }),
-                                       {2, 3, 5});
+        auto Constant_1992 = makeConst(element::i64, ov::Shape({3}), {2, 3, 5});
         auto ReduceSum_1993 = makeOP({Multiply_1991, Constant_1992}, {{"keep_dims", false}});
-        auto Concat_2034 = makeConst(element::i64,
-                                     ov::Shape({
-                                         3,
-                                     }),
-                                     {4, 3, 3});
+        // Broadcast target for ellipsis and labels is constant-folded into a single constant
+        auto Concat_2034 = makeConst(element::i64, ov::Shape({3}), {4, 3, 3});
+        // Broadcast ellipsis and labels
         auto Broadcast_2035 = makeOP({ReduceSum_1993, Concat_2034}, {{"mode", "bidirectional"}});
-        auto Concat_2051 = makeConst(element::i64,
-                                     ov::Shape({
-                                         4,
-                                     }),
-                                     {4, 3, 3, 1});
+        auto Concat_2051 = makeConst(element::i64, ov::Shape({4}), {4, 3, 3, 1});
         auto Reshape_2052 = makeOP({Broadcast_2035, Concat_2051}, {{"special_zero", false}});
-        auto Convert_1700 = makeConst(element::f32,
-                                      ov::Shape({
-                                          1,
-                                          1,
-                                          1,
-                                          1,
-                                          1,
-                                          1,
-                                      }),
-                                      {1.000000f});
+        auto Convert_1700 = makeConst(element::f32, ov::Shape({1, 1, 1, 1, 1, 1}), {1.000000f});
         auto Multiply_1701 = makeOP({node_4, Convert_1700}, {{"auto_broadcast", "numpy"}});
-        auto Constant_1702 = makeConst(element::i64,
-                                       ov::Shape({
-                                           1,
-                                       }),
-                                       {5});
+        auto Constant_1702 = makeConst(element::i64, ov::Shape({1}), {5});
         auto ReduceSum_1703 = makeOP({Multiply_1701, Constant_1702}, {{"keep_dims", false}});
-        auto Constant_1799 = makeConst(element::i64,
-                                       ov::Shape({
-                                           1,
-                                       }),
-                                       {1});
+        auto Constant_1799 = makeConst(element::i64, ov::Shape({1}), {1});
         auto ReduceSum_1800 = makeOP({ReduceSum_1703, Constant_1799}, {{"keep_dims", false}});
-        auto Constant_1803 = makeConst(element::i64,
-                                       ov::Shape({
-                                           2,
-                                       }),
-                                       {4, 5});
+        auto Constant_1803 = makeConst(element::i64, ov::Shape({2}), {4, 5});
         auto Unsqueeze_1804 = makeOP({ReduceSum_1800, Constant_1803});
-        auto Constant_1605 = makeConst(element::i64,
-                                       ov::Shape({
-                                           1,
-                                       }),
-                                       {0});
+        auto Constant_1605 = makeConst(element::i64, ov::Shape({1}), {0});
         auto Unsqueeze_1606 = makeOP({node_0, Constant_1605});
-        auto Constant_1607 = makeConst(element::i64,
-                                       ov::Shape({
-                                           2,
-                                       }),
-                                       {0, 1});
+        auto Constant_1607 = makeConst(element::i64, ov::Shape({2}), {0, 1});
         auto Unsqueeze_1608 = makeOP({Unsqueeze_1606, Constant_1607});
         auto Convert_1795 = makeConst(
             element::f32,
-            ov::Shape({
-                1,
-                1,
-                1,
-                1,
-                1,
-                3,
-                3,
-            }),
+            ov::Shape({1, 1, 1, 1, 1, 3, 3}),
             {1.000000f, 0.000000f, 0.000000f, 0.000000f, 1.000000f, 0.000000f, 0.000000f, 0.000000f, 1.000000f});
         auto Multiply_1796 = makeOP({Unsqueeze_1608, Convert_1795}, {{"auto_broadcast",
"numpy"}}); - auto Constant_1797 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {6}); + auto Constant_1797 = makeConst(element::i64, ov::Shape({1}), {6}); auto ReduceSum_1798 = makeOP({Multiply_1796, Constant_1797}, {{"keep_dims", false}}); - auto Constant_1801 = makeConst(element::i64, - ov::Shape({ - 6, - }), - {4, 0, 1, 2, 3, 5}); + auto Constant_1801 = makeConst(element::i64, ov::Shape({6}), {4, 0, 1, 2, 3, 5}); auto Transpose_1802 = makeOP({ReduceSum_1798, Constant_1801}); auto Multiply_1805 = makeOP({Unsqueeze_1804, Transpose_1802}, {{"auto_broadcast", "numpy"}}); - auto Constant_1994 = makeConst(element::i64, - ov::Shape({ - 6, - }), - {0, 5, 1, 2, 3, 4}); + auto Constant_1994 = makeConst(element::i64, ov::Shape({6}), {0, 5, 1, 2, 3, 4}); auto Transpose_1995 = makeOP({Multiply_1805, Constant_1994}); - auto Concat_2043 = makeConst(element::i64, - ov::Shape({ - 6, - }), - {4, 3, 2, 1, 1, 3}); + auto Concat_2043 = makeConst(element::i64, ov::Shape({6}), {4, 3, 2, 1, 1, 3}); auto Broadcast_2044 = makeOP({Transpose_1995, Concat_2043}, {{"mode", "bidirectional"}}); - auto Concat_2076 = makeConst(element::i64, - ov::Shape({ - 4, - }), - {4, 3, 2, 3}); + auto Concat_2076 = makeConst(element::i64, ov::Shape({4}), {4, 3, 2, 3}); auto Reshape_2077 = makeOP({Broadcast_2044, Concat_2076}, {{"special_zero", false}}); auto MatMul_2116 = makeOP({Reshape_2052, Reshape_2077}, {{"transpose_a", true}, {"transpose_b", true}}); - auto Concat_2117 = makeConst(element::i64, - ov::Shape({ - 5, - }), - {4, 3, 2, 1, 1}); + auto Concat_2117 = makeConst(element::i64, ov::Shape({5}), {4, 3, 2, 1, 1}); auto Reshape_2118 = makeOP({MatMul_2116, Concat_2117}, {{"special_zero", false}}); - auto Constant_2119 = makeConst(element::i64, - ov::Shape({ - 5, - }), - {1, 2, 3, 4, 0}); + auto Constant_2119 = makeConst(element::i64, ov::Shape({5}), {1, 2, 3, 4, 0}); auto node_6 = makeOP({Reshape_2118, Constant_2119}); model_ref = std::make_shared(NodeVector{node_6}, ParameterVector{node_4, node_2, node_0}); } From 3cf8b8320e443e626a48070e29b342afbc674123 Mon Sep 17 00:00:00 2001 From: Mateusz Mikolajczyk Date: Wed, 12 Feb 2025 19:18:55 +0100 Subject: [PATCH 36/47] Improve redability for einsum decomposition test Signed-off-by: Mateusz Mikolajczyk --- .../einsum_decomposition_test.cpp | 421 ++++-------------- 1 file changed, 92 insertions(+), 329 deletions(-) diff --git a/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp b/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp index 6bcb8ff5eb35ef..177f2c2ba46f3c 100644 --- a/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp +++ b/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp @@ -453,334 +453,97 @@ TEST_F(TransformationTestsF, Einsum_3in_broadcast_duplicated_ellipsis_repeated_d } { using namespace ov::gen_pattern; - auto node_0 = std::make_shared(element::f32, data_shape_3); - auto node_2 = std::make_shared(element::f32, data_shape_2); - auto node_4 = std::make_shared(element::f32, data_shape_1); - auto Constant_904 = makeConst(element::i64, ov::Shape({}), {0}); - auto ShapeOf_901 = makeOP({node_2}, {{"output_type", "i64"}}); - auto Constant_902 = makeConst(element::i64, - ov::Shape({ - 3, - }), - {1, 2, 3}); - auto Gather_906 = makeOP({ShapeOf_901, Constant_902, Constant_904}, {{"batch_dims", 0}}); - auto ReduceProd_1318 = makeOP({Gather_906, Constant_904}, {{"keep_dims", true}}); - auto Constant_1319 = makeConst(element::u8, ov::Shape({}), {0}); - auto 
Broadcast_1320 = - makeOP({Constant_904, ReduceProd_1318, Constant_1319}, {{"mode", "numpy"}}); - auto Gather_907 = makeOP({Gather_906, Constant_904, Constant_904}, {{"batch_dims", 0}}); - auto Constant_903 = makeConst(element::i64, ov::Shape({}), {3}); - auto Power_908 = makeOP({Gather_907, Constant_903}, {{"auto_broadcast", "numpy"}}); - auto Constant_905 = makeConst(element::i64, ov::Shape({}), {1}); - auto Subtract_909 = makeOP({Power_908, Constant_905}, {{"auto_broadcast", "numpy"}}); - auto Maximum_912 = makeOP({Subtract_909, Constant_905}, {{"auto_broadcast", "numpy"}}); - auto Subtract_910 = makeOP({Gather_907, Constant_905}, {{"auto_broadcast", "numpy"}}); - auto Maximum_911 = makeOP({Subtract_910, Constant_905}, {{"auto_broadcast", "numpy"}}); - auto Divide_913 = - makeOP({Maximum_912, Maximum_911}, {{"auto_broadcast", "numpy"}, {"m_pythondiv", true}}); - auto Range_914 = makeOP({Constant_904, Power_908, Divide_913}); - auto Unsqueeze_1313 = makeOP({Gather_907, Constant_904}); - auto Constant_1314 = makeConst(element::u8, ov::Shape({}), {0}); - auto Broadcast_1315 = - makeOP({Constant_905, Unsqueeze_1313, Constant_1314}, {{"mode", "numpy"}}); - auto ScatterElementsUpdate_1349 = - makeOP({Broadcast_1320, Range_914, Broadcast_1315, Constant_904}); - auto ShapeOf_1350 = makeOP({ShapeOf_901}); - auto Constant_1351 = makeConst(element::u8, ov::Shape({}), {0}); - auto Broadcast_1352 = - makeOP({Constant_905, ShapeOf_1350, Constant_1351}, {{"mode", "numpy"}}); - auto ScatterElementsUpdate_1355 = - makeOP({Broadcast_1352, Constant_902, Gather_906, Constant_904}); - auto Reshape_1356 = makeOP({ScatterElementsUpdate_1349, ScatterElementsUpdate_1355}, - {{"special_zero", false}}); - auto Constant_1361 = makeConst(element::i64, ov::Shape({}), {0}); - auto ShapeOf_1358 = makeOP({node_2}, {{"output_type", "i64"}}); - auto Constant_1359 = makeConst(element::i64, - ov::Shape({ - 2, - }), - {4, 5}); - auto Gather_1363 = makeOP({ShapeOf_1358, Constant_1359, Constant_1361}, {{"batch_dims", 0}}); - auto ReduceProd_1775 = makeOP({Gather_1363, Constant_1361}, {{"keep_dims", true}}); - auto Constant_1776 = makeConst(element::u8, ov::Shape({}), {0}); - auto Broadcast_1777 = - makeOP({Constant_1361, ReduceProd_1775, Constant_1776}, {{"mode", "numpy"}}); - auto Gather_1364 = makeOP({Gather_1363, Constant_1361, Constant_1361}, {{"batch_dims", 0}}); - auto Constant_1360 = makeConst(element::i64, ov::Shape({}), {2}); - auto Power_1365 = makeOP({Gather_1364, Constant_1360}, {{"auto_broadcast", "numpy"}}); - auto Constant_1362 = makeConst(element::i64, ov::Shape({}), {1}); - auto Subtract_1366 = makeOP({Power_1365, Constant_1362}, {{"auto_broadcast", "numpy"}}); - auto Maximum_1369 = makeOP({Subtract_1366, Constant_1362}, {{"auto_broadcast", "numpy"}}); - auto Subtract_1367 = makeOP({Gather_1364, Constant_1362}, {{"auto_broadcast", "numpy"}}); - auto Maximum_1368 = makeOP({Subtract_1367, Constant_1362}, {{"auto_broadcast", "numpy"}}); - auto Divide_1370 = - makeOP({Maximum_1369, Maximum_1368}, {{"auto_broadcast", "numpy"}, {"m_pythondiv", true}}); - auto Range_1371 = makeOP({Constant_1361, Power_1365, Divide_1370}); - auto Unsqueeze_1770 = makeOP({Gather_1364, Constant_1361}); - auto Constant_1771 = makeConst(element::u8, ov::Shape({}), {0}); - auto Broadcast_1772 = - makeOP({Constant_1362, Unsqueeze_1770, Constant_1771}, {{"mode", "numpy"}}); - auto ScatterElementsUpdate_1806 = - makeOP({Broadcast_1777, Range_1371, Broadcast_1772, Constant_1361}); - auto ShapeOf_1807 = makeOP({ShapeOf_1358}); - auto 
Constant_1808 = makeConst(element::u8, ov::Shape({}), {0}); - auto Broadcast_1809 = - makeOP({Constant_1362, ShapeOf_1807, Constant_1808}, {{"mode", "numpy"}}); - auto ScatterElementsUpdate_1812 = - makeOP({Broadcast_1809, Constant_1359, Gather_1363, Constant_1361}); - auto Reshape_1813 = makeOP({ScatterElementsUpdate_1806, ScatterElementsUpdate_1812}, - {{"special_zero", false}}); - auto Multiply_1815 = makeOP({Reshape_1356, Reshape_1813}, {{"auto_broadcast", "numpy"}}); - auto ConvertLike_1816 = makeOP({Multiply_1815, node_2}); - auto Multiply_1816 = makeOP({node_2, ConvertLike_1816}, {{"auto_broadcast", "numpy"}}); - auto Constant_1817 = makeConst(element::i64, - ov::Shape({ - 3, - }), - {2, 3, 5}); - auto ReduceSum_1818 = makeOP({Multiply_1816, Constant_1817}, {{"keep_dims", false}}); - auto Constant_1833 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {1}); - auto ShapeOf_1821 = makeOP({ReduceSum_1818}, {{"output_type", "i64"}}); - auto Constant_1823 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {0}); - auto Constant_1824 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {2}); - auto Constant_1826 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {1}); - auto StridedSlice_1827 = - makeOP({ShapeOf_1821, Constant_1823, Constant_1824, Constant_1826}, - {{"begin_mask", {0}}, - {"end_mask", {0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto Broadcast_1834 = makeOP({Constant_1833, StridedSlice_1827}, {{"mode", "numpy"}}); - auto Constant_894 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {0}); - auto ReduceSum_895 = makeOP({node_4, Constant_894}, {{"keep_dims", false}}); - auto Constant_898 = makeConst(element::i64, - ov::Shape({ - 2, - }), - {4, 5}); - auto Unsqueeze_899 = makeOP({ReduceSum_895, Constant_898}); - auto Constant_430 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {0}); - auto Unsqueeze_431 = makeOP({node_0, Constant_430}); - auto Constant_432 = makeConst(element::i64, - ov::Shape({ - 2, - }), - {0, 1}); - auto Unsqueeze_433 = makeOP({Unsqueeze_431, Constant_432}); - auto Constant_437 = makeConst(element::i64, ov::Shape({}), {0}); - auto ShapeOf_434 = makeOP({Unsqueeze_433}, {{"output_type", "i64"}}); - auto Constant_435 = makeConst(element::i64, - ov::Shape({ - 2, - }), - {5, 6}); - auto Gather_439 = makeOP({ShapeOf_434, Constant_435, Constant_437}, {{"batch_dims", 0}}); - auto ReduceProd_851 = makeOP({Gather_439, Constant_437}, {{"keep_dims", true}}); - auto Constant_852 = makeConst(element::u8, ov::Shape({}), {0}); - auto Broadcast_853 = - makeOP({Constant_437, ReduceProd_851, Constant_852}, {{"mode", "numpy"}}); - auto Gather_440 = makeOP({Gather_439, Constant_437, Constant_437}, {{"batch_dims", 0}}); - auto Constant_436 = makeConst(element::i64, ov::Shape({}), {2}); - auto Power_441 = makeOP({Gather_440, Constant_436}, {{"auto_broadcast", "numpy"}}); - auto Constant_438 = makeConst(element::i64, ov::Shape({}), {1}); - auto Subtract_442 = makeOP({Power_441, Constant_438}, {{"auto_broadcast", "numpy"}}); - auto Maximum_445 = makeOP({Subtract_442, Constant_438}, {{"auto_broadcast", "numpy"}}); - auto Subtract_443 = makeOP({Gather_440, Constant_438}, {{"auto_broadcast", "numpy"}}); - auto Maximum_444 = makeOP({Subtract_443, Constant_438}, {{"auto_broadcast", "numpy"}}); - auto Divide_446 = - makeOP({Maximum_445, Maximum_444}, {{"auto_broadcast", "numpy"}, {"m_pythondiv", true}}); - auto Range_447 = makeOP({Constant_437, Power_441, Divide_446}); - auto Unsqueeze_846 = makeOP({Gather_440, Constant_437}); - 
auto Constant_847 = makeConst(element::u8, ov::Shape({}), {0}); - auto Broadcast_848 = - makeOP({Constant_438, Unsqueeze_846, Constant_847}, {{"mode", "numpy"}}); - auto ScatterElementsUpdate_882 = - makeOP({Broadcast_853, Range_447, Broadcast_848, Constant_437}); - auto ShapeOf_883 = makeOP({ShapeOf_434}); - auto Constant_884 = makeConst(element::u8, ov::Shape({}), {0}); - auto Broadcast_885 = makeOP({Constant_438, ShapeOf_883, Constant_884}, {{"mode", "numpy"}}); - auto ScatterElementsUpdate_888 = - makeOP({Broadcast_885, Constant_435, Gather_439, Constant_437}); - auto Reshape_889 = - makeOP({ScatterElementsUpdate_882, ScatterElementsUpdate_888}, {{"special_zero", false}}); - auto ConvertLike_890 = makeOP({Reshape_889, Unsqueeze_433}); - auto Multiply_891 = makeOP({Unsqueeze_433, ConvertLike_890}, {{"auto_broadcast", "numpy"}}); - auto Constant_892 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {6}); - auto ReduceSum_893 = makeOP({Multiply_891, Constant_892}, {{"keep_dims", false}}); - auto Constant_896 = makeConst(element::i64, - ov::Shape({ - 6, - }), - {0, 1, 2, 4, 3, 5}); - auto Transpose_897 = makeOP({ReduceSum_893, Constant_896}); - auto Multiply_900 = makeOP({Unsqueeze_899, Transpose_897}, {{"auto_broadcast", "numpy"}}); - auto Constant_1819 = makeConst(element::i64, - ov::Shape({ - 6, - }), - {3, 5, 0, 1, 2, 4}); - auto Transpose_1820 = makeOP({Multiply_900, Constant_1819}); - auto ShapeOf_1822 = makeOP({Transpose_1820}, {{"output_type", "i64"}}); - auto Constant_1828 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {0}); - auto Constant_1829 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {2}); - auto Constant_1831 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {1}); - auto StridedSlice_1832 = - makeOP({ShapeOf_1822, Constant_1828, Constant_1829, Constant_1831}, - {{"begin_mask", {0}}, - {"end_mask", {0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto Broadcast_1837 = - makeOP({Broadcast_1834, StridedSlice_1832}, {{"mode", "bidirectional"}}); - auto ShapeOf_1840 = makeOP({Broadcast_1837}, {{"output_type", "i64"}}); - auto Constant_1851 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {1}); - auto Constant_1841 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {2}); - auto Constant_1842 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {3}); - auto Constant_1844 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {1}); - auto StridedSlice_1845 = - makeOP({ShapeOf_1821, Constant_1841, Constant_1842, Constant_1844}, - {{"begin_mask", {0}}, - {"end_mask", {0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto Broadcast_1852 = makeOP({Constant_1851, StridedSlice_1845}, {{"mode", "numpy"}}); - auto Constant_1846 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {5}); - auto Constant_1847 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {6}); - auto Constant_1849 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {1}); - auto StridedSlice_1850 = - makeOP({ShapeOf_1822, Constant_1846, Constant_1847, Constant_1849}, - {{"begin_mask", {0}}, - {"end_mask", {0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto Broadcast_1855 = - makeOP({Broadcast_1852, StridedSlice_1850}, {{"mode", "bidirectional"}}); - auto ShapeOf_1858 = makeOP({Broadcast_1855}, {{"output_type", "i64"}}); - auto Concat_1859 = makeOP({ShapeOf_1840, ShapeOf_1858}, {{"axis", 0}}); - auto Broadcast_1860 = makeOP({ReduceSum_1818, Concat_1859}, {{"mode", 
"bidirectional"}}); - auto ReduceProd_1875 = makeOP({ShapeOf_1858, {0}}, {{"keep_dims", true}}); - auto Constant_1873 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {1}); - auto Concat_1876 = makeOP({ShapeOf_1840, ReduceProd_1875, Constant_1873}, {{"axis", 0}}); - auto Reshape_1903 = makeOP({Broadcast_1860, Concat_1876}, {{"special_zero", false}}); - auto Constant_1863 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {2}); - auto Constant_1864 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {5}); - auto Constant_1866 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {1}); - auto StridedSlice_1867 = - makeOP({ShapeOf_1822, Constant_1863, Constant_1864, Constant_1866}, - {{"begin_mask", {0}}, - {"end_mask", {0}}, - {"new_axis_mask", {}}, - {"shrink_axis_mask", {}}, - {"ellipsis_mask", {}}}); - auto Concat_1868 = makeOP({ShapeOf_1840, StridedSlice_1867, ShapeOf_1858}, {{"axis", 0}}); - auto Broadcast_1869 = makeOP({Transpose_1820, Concat_1868}, {{"mode", "bidirectional"}}); - auto Constant_1904 = makeConst(element::i64, - ov::Shape({ - 1, - }), - {0}); - auto ReduceProd_1905 = makeOP({StridedSlice_1867, Constant_1904}, {{"keep_dims", true}}); - auto ReduceProd_1907 = makeOP({ShapeOf_1858, {0}}, {{"keep_dims", true}}); - auto Concat_1908 = makeOP({ShapeOf_1840, ReduceProd_1905, ReduceProd_1907}, {{"axis", 0}}); - auto Reshape_1961 = makeOP({Broadcast_1869, Concat_1908}, {{"special_zero", false}}); - auto MatMul_1962 = - makeOP({Reshape_1903, Reshape_1961}, {{"transpose_a", true}, {"transpose_b", true}}); - auto Concat_1963 = makeOP({ShapeOf_1840, StridedSlice_1867}, {{"axis", 0}}); - auto Reshape_1964 = makeOP({MatMul_1962, Concat_1963}, {{"special_zero", false}}); - auto Constant_1965 = makeConst(element::i64, - ov::Shape({ - 5, - }), - {1, 2, 3, 4, 0}); - auto node_6 = makeOP({Reshape_1964, Constant_1965}); - model_ref = std::make_shared(NodeVector{node_6}, ParameterVector{node_0, node_2, node_4}); + auto data_1 = std::make_shared(element::f32, data_shape_1); + auto data_2 = std::make_shared(element::f32, data_shape_2); + auto data_3 = std::make_shared(element::f32, data_shape_3); + + // First pair of einsum inputs - data_1 and data_3 + // data_1 - label `a` can be reduced by reduce_input() + auto indice_of_a_in_data_1 = makeConst(element::i64, ov::Shape({1}), {0}); + auto data_1_processed = makeOP({data_1, indice_of_a_in_data_1}, {{"keep_dims", false}}); + // data_3 - unsqueeze ellipse labels to allow for broadcasting and handle repeated labels + auto ellipsis_idx = makeConst(element::i64, ov::Shape({1}), {0}); + auto data3_insert_missing_ellipsis = makeOP({data_3, ellipsis_idx}); + auto align_ellipsis_idx = makeConst(element::i64, ov::Shape({2}), {0, 1}); + auto data_3_processed = makeOP({data3_insert_missing_ellipsis, align_ellipsis_idx}); + auto data_3_diagonal = extract_diagonal(data_3_processed, {{5, 6}}); + + // No reduced labels - use simplified subgraph that uses Multiply instead Matmul + auto convenient_layout = makeConst(element::i64, ov::Shape({6}), {0, 1, 2, 4, 3, 5}); + // ...dbc -> ...bdc + auto rhs_convenient_layout = makeOP({data_3_diagonal, convenient_layout}); + // Optionally unsqueeze both operands for elementwise-multiplication with broadcasting + // For LHS operand, unsqueeze at RHS separate dimensions indices (placed at end of RHS by transpose) + auto lhs_unsqueeze_dims = makeConst(element::i64, ov::Shape({2}), {4, 5}); + auto lhs_unsqueeze = makeOP({data_1_processed, lhs_unsqueeze_dims}); + // Out subscript = LHS_subscript + 
RHS_separate_part_subscript
+    // ...bdc = ...b + dc
+    auto data_1_3 = makeOP({lhs_unsqueeze, rhs_convenient_layout}, {{"auto_broadcast", "numpy"}});
+
+    // Second pair of einsum inputs - data_2 and result of the first pair
+    // bcccdd,...bdc->c...b
+    // data_2 - handle repeated labels
+    auto data_2_diagonal = extract_diagonal(data_2,
+                                            {
+                                                {1, 2, 3},  // indices of repeated label c
+                                                {4, 5},     // indices of repeated label d
+                                            });
+    // data_1_3 - transpose to correctly group common, separate and reduced labels
+    // ...bdc->bc...d
+    auto transpose_data_1_3_target = makeConst(element::i64, ov::Shape({6}), {3, 5, 0, 1, 2, 4});
+    auto data_1_3_processed = makeOP({data_1_3, transpose_data_1_3_target});
+    // Extract and broadcast common subshapes (bc)
+    auto shapeof_data_1_3 = makeOP({data_1_3_processed}, {{"output_type", "i64"}});
+    auto common_data_1_3 = extract_subshape_from_shape(shapeof_data_1_3, 0, 2);
+    auto shapeof_data_2 = makeOP({data_2_diagonal}, {{"output_type", "i64"}});
+    auto common_data_2 = extract_subshape_from_shape(shapeof_data_2, 0, 2);
+    auto common_broadcast_merge_shapes = broadcast_merge_shapes(common_data_2, common_data_1_3);
+
+    // Extract and broadcast reduced subshapes (d)
+    auto reduced_data_2 = extract_subshape_from_shape(shapeof_data_2, 2, 3);
+    auto reduced_data_1_3 = extract_subshape_from_shape(shapeof_data_1_3, 5, 6);
+    auto reduced_broadcast_merge_shapes = broadcast_merge_shapes(reduced_data_2, reduced_data_1_3);
+
+    // Extract and broadcast separate subshapes if needed
+    auto separate_data_1_3 = extract_subshape_from_shape(shapeof_data_1_3, 2, 5);  // (...)
+
+    // Broadcast data_2 and data_1_3 based on calculated subshapes
+    auto broadcast_data_2_target =
+        makeOP({common_broadcast_merge_shapes, reduced_broadcast_merge_shapes}, {{"axis", 0}});
+    auto broadcast_data_2 =
+        makeOP({data_2_diagonal, broadcast_data_2_target}, {{"mode", "bidirectional"}});
+    auto broadcast_data_1_3_target =
+        makeOP({common_broadcast_merge_shapes, separate_data_1_3, reduced_broadcast_merge_shapes},
+               {{"axis", 0}});
+    auto broadcast_data_1_3 =
+        makeOP({data_1_3_processed, broadcast_data_1_3_target}, {{"mode", "bidirectional"}});
+
+    // Optionally reshape broadcasted data_2 and data_1_3 so separate and reduced labels are represented by one
+    // dimension. Subgraphs are constant-folded; target subshapes are calculated by the broadcast_merge_shapes function.
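+    // Illustrative layout (an assumption based on this test's dimensions): after the reshapes below,
+    // data_2 becomes (common..., reduced_prod, 1) and data_1_3 becomes (common..., separate_prod, reduced_prod),
+    // so a single MatMul with transpose_a/transpose_b contracts the reduced_prod axis.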
+    auto reduced_prod = makeOP({reduced_broadcast_merge_shapes, {0}}, {{"keep_dims", true}});
+    // Reshape data_2
+    auto separate_data_2_placeholder = makeConst(element::i64, ov::Shape({1}), {1});
+    auto reshape_data_2_target =
+        makeOP({common_broadcast_merge_shapes, reduced_prod, separate_data_2_placeholder},
+               {{"axis", 0}});
+    auto reshape_data_2 =
+        makeOP({broadcast_data_2, reshape_data_2_target}, {{"special_zero", false}});
+    // Reshape data_1_3
+    auto Constant_1904 = makeConst(element::i64, ov::Shape({1}), {0});
+    auto separate_data_1_3_prod =
+        makeOP({separate_data_1_3, Constant_1904}, {{"keep_dims", true}});
+    auto reshape_data_1_3_target =
+        makeOP({common_broadcast_merge_shapes, separate_data_1_3_prod, reduced_prod},
+               {{"axis", 0}});
+    auto reshape_data_1_3 =
+        makeOP({broadcast_data_1_3, reshape_data_1_3_target}, {{"special_zero", false}});
+    auto matmul =
+        makeOP({reshape_data_2, reshape_data_1_3}, {{"transpose_a", true}, {"transpose_b", true}});
+    auto reshape_out_subshape =
+        makeOP({common_broadcast_merge_shapes, separate_data_1_3}, {{"axis", 0}});
+    auto reshape_out = makeOP({matmul, reshape_out_subshape}, {{"special_zero", false}});
+    auto Constant_1965 = makeConst(element::i64, ov::Shape({5}), {1, 2, 3, 4, 0});
+    auto transpose_out = makeOP({reshape_out, Constant_1965});
+    model_ref = std::make_shared(NodeVector{transpose_out}, ParameterVector{data_1, data_2, data_3});
     }
 }

From a5d5db4ca2ba02f07a35c5ad5b97ad0448caeb5e Mon Sep 17 00:00:00 2001
From: Mateusz Mikolajczyk
Date: Thu, 13 Feb 2025 12:54:56 +0100
Subject: [PATCH 37/47] Modify einsum broadcast_merge_shapes to use Maximum

Signed-off-by: Mateusz Mikolajczyk
---
 .../op_conversions/einsum_decomposition.cpp   | 25 +++++++------------
 .../einsum_decomposition_test.cpp             |  6 +----
 2 files changed, 10 insertions(+), 21 deletions(-)

diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp
index 05da3e0bd3dd15..4fd64c6fc21713 100644
--- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp
+++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp
@@ -350,12 +350,11 @@ ov::Output unsqueeze_input(const ov::Output& input_node,
     return unsqueeze->output(0);
 }

-/// \brief Broadcasts and merges two shapes using specified broadcasting rules.
+/// \brief Broadcasts and merges two shapes of the same rank.
 ///
-/// This function takes two shapes (shapes_lhs and shapes_rhs) and attempts to broadcast
-/// and merge them into a single shape using NumPy and bidirectional broadcasting rules. The resulting
-/// broadcasted shape is returned as an OutputVector. If one of the input vectors is empty, the other
-/// vector is returned as is.
+/// This function takes two shapes (shapes_lhs and shapes_rhs) of the same rank and attempts to broadcast
+/// and merge them into a single shape. The resulting broadcasted shape is returned as an OutputVector.
+/// If one of the input vectors is empty, the other vector is returned as is.
 ///
 /// \param shapes_lhs A single element vector containing the left-hand side shape to be broadcasted or empty.
 /// \param shapes_rhs A single element vector containing the right-hand side shape to be broadcasted or empty.
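A note on the invariant behind this change: for two equal-rank shapes in which every dimension pair is either equal or contains a 1, NumPy broadcasting always resolves each pair to the larger extent, so an element-wise Maximum over the two shape tensors reproduces the merged shape. A minimal standalone sketch of that invariant (a hypothetical helper for illustration, not code from this patch):

    #include <algorithm>
    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Equal-rank broadcast merge: every output dimension is max(lhs[i], rhs[i]).
    // This matches NumPy broadcasting whenever lhs[i] == rhs[i] or one of them is 1.
    std::vector<int64_t> merge_equal_rank(const std::vector<int64_t>& lhs,
                                          const std::vector<int64_t>& rhs) {
        assert(lhs.size() == rhs.size());
        std::vector<int64_t> merged(lhs.size());
        std::transform(lhs.begin(), lhs.end(), rhs.begin(), merged.begin(),
                       [](int64_t a, int64_t b) { return std::max(a, b); });
        return merged;
    }

    // merge_equal_rank({2, 1, 5}, {1, 4, 5}) == {2, 4, 5}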
@@ -369,17 +368,11 @@ ov::OutputVector broadcast_merge_shapes(ov::OutputVector& shapes_lhs,
     ov::OutputVector broadcasted_shape_nodes{};
     // OutputVector is either empty or contains a single shape
     if (shapes_lhs.size() == 1 && shapes_rhs.size() == 1) {
-        auto const_1 = ov::op::v0::Constant::create(ov::element::Type_t::i64, ov::Shape{1}, {1});
-        auto tmp_const_of_lhs_shp =
-            std::make_shared(const_1, shapes_lhs[0], ov::op::BroadcastType::NUMPY);
-        auto tmp_const_of_broadcasted_shp =
-            std::make_shared(tmp_const_of_lhs_shp,
-                             shapes_rhs[0],
-                             ov::op::BroadcastType::BIDIRECTIONAL);
-        auto broadcasted_shape = std::make_shared(tmp_const_of_broadcasted_shp);
-        broadcasted_shape_nodes.push_back(broadcasted_shape->output(0));
-        subgraph_nodes.insert(subgraph_nodes.end(),
-                              {const_1, tmp_const_of_lhs_shp, tmp_const_of_broadcasted_shp, broadcasted_shape});
+        // For common and reduced subshapes, the same rank should already be ensured by the function
+        // `unsqueeze_ellipses_to_same_rank`.
+        const auto& maximum = std::make_shared(shapes_lhs[0], shapes_rhs[0]);
+        subgraph_nodes.push_back(maximum);
+        broadcasted_shape_nodes.push_back(maximum);
     } else if (shapes_lhs.size() == 0 && shapes_rhs.size() == 1) {
         return shapes_rhs;
     } else if (shapes_lhs.size() == 1 && shapes_rhs.size() == 0) {
diff --git a/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp b/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp
index 177f2c2ba46f3c..55fc5a9c3a1fb6 100644
--- a/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp
+++ b/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp
@@ -31,11 +31,7 @@ std::shared_ptr extract_subshape_from_shape(const std::shared_ptr
 std::shared_ptr broadcast_merge_shapes(const std::shared_ptr& shape_node_lhs,
                                        const std::shared_ptr& shape_node_rhs) {
-    auto const_1 = makeConst(element::i64, ov::Shape({1}), {1});
-    auto tensor_of_lhs_shape = makeOP({const_1, shape_node_lhs}, {{"mode", "numpy"}});
-    auto tensor_of_broadcasted_lhs_rhs_shape =
-        makeOP({tensor_of_lhs_shape, shape_node_rhs}, {{"mode", "bidirectional"}});
-    auto broadcasted_shapes = makeOP({tensor_of_broadcasted_lhs_rhs_shape}, {{"output_type", "i64"}});
+    auto broadcasted_shapes = makeOP({shape_node_lhs, shape_node_rhs}, {{"auto_broadcast", "numpy"}});
     return broadcasted_shapes;
 }

From 8c447296e3fbfac11d4cf8861050e843e18b8ef5 Mon Sep 17 00:00:00 2001
From: Mateusz Mikolajczyk
Date: Thu, 13 Feb 2025 12:55:58 +0100
Subject: [PATCH 38/47] Fix typo in einsum decomposition

Signed-off-by: Mateusz Mikolajczyk
---
 .../src/transformations/op_conversions/einsum_decomposition.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp
index 4fd64c6fc21713..8034bb3180b04c 100644
--- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp
+++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp
@@ -642,7 +642,7 @@ void reduce_input(ov::OutputVector& input_nodes,
 /// \brief Builds an n-dimensional identity tensor based on the input node and repeated label dimensions.
 ///
-/// This function constructs an identity tenosor matching number of dimensions of the number of repeats for a single
+/// This function constructs an identity tensor matching number of dimensions of the number of repeats for a single
 /// label.
/// /// \param input_node The input node for which the identity tensor is to be built. From e273a50ff64be1d4a47363795ec7333177727a95 Mon Sep 17 00:00:00 2001 From: Mateusz Mikolajczyk Date: Tue, 18 Feb 2025 14:07:22 +0100 Subject: [PATCH 39/47] Refactor handling of repeated labels in einsum decomposition to not create multi-identity Signed-off-by: Mateusz Mikolajczyk --- .../op_conversions/einsum_decomposition.cpp | 256 ++++++++---------- 1 file changed, 115 insertions(+), 141 deletions(-) diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp index 8034bb3180b04c..2040e320547930 100644 --- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp @@ -19,6 +19,7 @@ #include "openvino/op/matmul.hpp" #include "openvino/op/maximum.hpp" #include "openvino/op/multiply.hpp" +#include "openvino/op/pad.hpp" #include "openvino/op/power.hpp" #include "openvino/op/range.hpp" #include "openvino/op/reduce_prod.hpp" @@ -26,6 +27,7 @@ #include "openvino/op/reshape.hpp" #include "openvino/op/scatter_elements_update.hpp" #include "openvino/op/shape_of.hpp" +#include "openvino/op/squeeze.hpp" #include "openvino/op/strided_slice.hpp" #include "openvino/op/subtract.hpp" #include "openvino/op/transpose.hpp" @@ -554,7 +556,6 @@ void transpose_input(ov::OutputVector& input_nodes, const auto& input_node = input_nodes[input_ind]; const auto labels = ov::op::v7::Einsum::extract_labels(input_subscript); const auto required_labels = ov::op::v7::Einsum::extract_labels(required_subscript); - OPENVINO_ASSERT(labels.size() == required_labels.size()); const auto label_dim_map = compute_label_dim_map(input_node.get_partial_shape().rank(), input_subscript); for (const auto& required_label : required_labels) { const auto label_dims_it = label_dim_map.find(required_label); @@ -640,120 +641,6 @@ void reduce_input(ov::OutputVector& input_nodes, subgraph_nodes.insert(subgraph_nodes.end(), {axes_const, reduce_sum}); } -/// \brief Builds an n-dimensional identity tensor based on the input node and repeated label dimensions. -/// -/// This function constructs an identity tensor matching number of dimensions of the number of repeats for a single -/// label. -/// -/// \param input_node The input node for which the identity tensor is to be built. -/// \param repeated_label_dims A vector containing the dimensions of the repeated label. -/// \param subgraph_nodes A vector of operation nodes that is included into -/// a sub-graph decomposing Einsum that is needed for copy_runtime_info -/// -/// \return The final node representing the identity tensor, reshaped to match input rank and correct dimensions -/// with repeated labels. -ov::Output build_identity(const ov::Output& input_node, - const std::vector& repeated_label_dims, - ov::NodeVector& subgraph_nodes) { - OPENVINO_ASSERT(repeated_label_dims.size() > 1); - // Create flattened (repeated_label_dims.size())-dimensional eye tensor with 1s on the diagonal. 
- const auto input_shape = std::make_shared(input_node); - const auto repeated_label_indices = - ov::op::v0::Constant::create(ov::element::i64, {repeated_label_dims.size()}, repeated_label_dims); - const auto repeated_label_indices_len = - ov::op::v0::Constant::create(ov::element::i64, {}, {repeated_label_dims.size()}); - const auto const_0 = ov::op::v0::Constant::create(ov::element::i64, {}, {0}); - const auto const_1 = ov::op::v0::Constant::create(ov::element::i64, {}, {1}); - const auto repeated_dimensions = std::make_shared(input_shape, repeated_label_indices, const_0); - const auto repeated_dimension = std::make_shared(repeated_dimensions, const_0, const_0); - const auto range_max_val = std::make_shared(repeated_dimension, repeated_label_indices_len); - const auto step_numerator = std::make_shared(range_max_val, const_1); - const auto step_denominator = std::make_shared(repeated_dimension, const_1); - const auto step_denominator_but_not_0 = std::make_shared(step_denominator, const_1); - const auto step_numerator_but_not_0 = std::make_shared(step_numerator, const_1); - const auto step = std::make_shared(step_numerator_but_not_0, step_denominator_but_not_0); - const auto eye_flattened_indices = std::make_shared(const_0, range_max_val, step); - const auto repeated_dimension_1d = std::make_shared(repeated_dimension, const_0); - const auto ones = std::make_shared(const_1, repeated_dimension_1d); - const auto reduced_size = std::make_shared(repeated_dimensions, const_0, true); - const auto zeros = std::make_shared(const_0, reduced_size); - const auto eye_flattened = - std::make_shared(zeros, eye_flattened_indices, ones, const_0); - - // Prepare target shape for identity tensor for specified repeated label dimensions. - const auto identity_rank = std::make_shared(input_shape); - const auto ones_of_input_shape_rank = std::make_shared(const_1, identity_rank); - const auto identity_shape = std::make_shared(ones_of_input_shape_rank, - repeated_label_indices, - repeated_dimensions, - const_0); - - // Reshape the flattened identity tensor to the target shape. - const auto identity = std::make_shared(eye_flattened, identity_shape, false); - subgraph_nodes.insert(subgraph_nodes.end(), - {input_shape, - repeated_label_indices, - repeated_label_indices_len, - const_0, - const_1, - repeated_dimensions, - repeated_dimension, - range_max_val, - step_numerator, - step_denominator, - step_denominator_but_not_0, - step_numerator_but_not_0, - step, - eye_flattened_indices, - repeated_dimension_1d, - ones, - reduced_size, - zeros, - eye_flattened, - identity_rank, - ones_of_input_shape_rank, - identity_shape, - identity}); - return subgraph_nodes.back(); -} - -/// \brief Builds a multi-identity node by multiplying identity nodes for each repeated label. -/// -/// This function constructs a multi-identity node by iteratively multiplying identity nodes -/// corresponding to each repeated label. The identity nodes are built using the provided -/// input node and label dimension map. -/// -/// \param input_node The input node for which the identity nodes are to be built. -/// \param repeated_labels A vector of repeated labels for which identity nodes are to be created. -/// \param label_dim_map A map from labels to their corresponding dimensions. -/// \param subgraph_nodes A vector of operation nodes that is included into -/// a sub-graph decomposing Einsum that is needed for copy_runtime_info -/// \return The final multi-identity node after multiplying all identity nodes. 
-/// -ov::Output build_multi_identity(const ov::Output& input_node, - const std::vector& repeated_labels, - const LabelDimMap& label_dim_map, - ov::NodeVector& subgraph_nodes) { - OPENVINO_ASSERT(repeated_labels.size() > 0); - - const auto get_identity = [&](size_t idx) { - const auto repeated_label_dims = label_dim_map.find(repeated_labels[idx]); - OPENVINO_ASSERT(repeated_label_dims != label_dim_map.end()); - return build_identity(input_node, repeated_label_dims->second, subgraph_nodes); - }; - - // initially set multi-identity with identity for the first repeated label - auto multi_identity = get_identity(0).get_node_shared_ptr(); - for (size_t label_ind = 1; label_ind < repeated_labels.size(); ++label_ind) { - const auto identity = get_identity(label_ind); - multi_identity = - std::make_shared(multi_identity, identity, ov::op::AutoBroadcastType::NUMPY); - subgraph_nodes.insert(subgraph_nodes.end(), {multi_identity}); - } - - return subgraph_nodes.back(); -} - /// \brief Prepares data for diagonal extraction in Einsum operation. /// /// This function processes the input subscript and label-dimension map to identify repeated labels, @@ -762,17 +649,22 @@ ov::Output build_multi_identity(const ov::Output& input_node /// \param input_subscript The input subscript string representing the Einsum equation. /// \param label_dim_map A map from labels to their corresponding dimensions. /// \param resultant_subscript A reference to the resultant subscript string to be updated. -/// \param repeated_labels A reference to a vector of strings to store repeated labels found in the input subscript. -/// \param reduced_axes A reference to an AxisSet to store the axes that need to be reduced. +/// \param resultant_subscript_with_duplicates A reference to a resultant subscript string with duplicates. +/// \param repeated_labels A reference to a vector of strings to store repeated labels found in input subscript. +/// \param unrepeated_labels A reference to a vector where unrepeated labels will be stored. +/// \param reduced_axes A reference to an AxisVector to store the axes that need to be reduced. 
/// void prepare_diagonal_extraction_data(const std::string& input_subscript, const LabelDimMap& label_dim_map, std::string& resultant_subscript, + std::string& resultant_subscript_with_duplicates, std::vector& repeated_labels, - ov::AxisSet& reduced_axes) { + std::vector& unrepeated_labels, + ov::AxisVector& reduced_axes) { static const std::string ellipsis = "..."; const auto labels = ov::op::v7::Einsum::extract_labels(input_subscript); - + std::vector repeated_labels_with_duplicates; + size_t reduced_dim = 1; for (const auto& label : labels) { if (resultant_subscript.find(label) != std::string::npos) { continue; @@ -787,23 +679,30 @@ void prepare_diagonal_extraction_data(const std::string& input_subscript, if (label != ellipsis && dims_size > 1) { // repeated label is found - for (size_t dim_ind = 1; dim_ind < dims_size; ++dim_ind) { - reduced_axes.insert(dims[dim_ind]); - } // save only the first dimension corresponding to the repeated label dims = {dims[0]}; + reduced_axes.push_back(reduced_dim); repeated_labels.push_back(label); + repeated_labels_with_duplicates.insert(repeated_labels_with_duplicates.end(), dims_size, label); + reduced_dim += 2; + resultant_subscript += label; + } else { + unrepeated_labels.push_back(label); } - resultant_subscript += label; } + resultant_subscript = std::accumulate(unrepeated_labels.begin(), unrepeated_labels.end(), resultant_subscript); + resultant_subscript_with_duplicates = std::accumulate(repeated_labels_with_duplicates.begin(), + repeated_labels_with_duplicates.end(), + std::string("")); + resultant_subscript_with_duplicates = + std::accumulate(unrepeated_labels.begin(), unrepeated_labels.end(), resultant_subscript_with_duplicates); } /// /// \brief Extracts the diagonal elements from the input tensor based on the provided subscripts. /// /// This function modifies the input tensor by extracting its diagonal elements for repeated lables and updating the -/// corresponding subscript. The diagonal extraction is performed by multiplying the input tensor with a -/// multi-identity. +/// corresponding subscript. /// /// \param inputs A vector of input tensors. /// \param input_subscripts A vector of subscripts corresponding to each input tensor. 
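The extract_diagonal hunks below replace the multi-identity multiply with a reshape-pad-reshape-gather pipeline. The index arithmetic it relies on can be sketched in isolation (a hypothetical helper for illustration, assuming a label extent d >= 2; the real subgraph clamps the d == 1 case with Maximum nodes):

    #include <cstddef>
    #include <vector>

    // For a label of extent d repeated k times, flattening its k dimensions into one
    // axis of size d^k puts the diagonal entries at stride s = (d^k - 1) / (d - 1).
    // Padding that axis with s - 1 trailing zeros and viewing it as a {d, s} matrix
    // moves the diagonal into column 0, where a single Gather(index 0) extracts it.
    std::vector<float> diagonal_of_flattened(const std::vector<float>& flat,
                                             std::size_t d,
                                             std::size_t k) {
        std::size_t total = 1;
        for (std::size_t i = 0; i < k; ++i) {
            total *= d;  // total == d^k == flat.size()
        }
        const std::size_t stride = (total - 1) / (d - 1);  // the "step" built below
        std::vector<float> diag(d);
        for (std::size_t i = 0; i < d; ++i) {
            diag[i] = flat[i * stride];  // row i, column 0 of the padded {d, stride} view
        }
        return diag;
    }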
@@ -816,7 +715,7 @@ void extract_diagonal(ov::OutputVector& inputs, size_t input_ind, ov::NodeVector& subgraph_nodes) { // perform sanity check for arguments - const auto num_inputs = inputs.size(); + const auto& num_inputs = inputs.size(); OPENVINO_ASSERT(num_inputs == input_subscripts.size(), "Each input must have own subscript."); OPENVINO_ASSERT(input_ind < num_inputs, "Input index is out of range."); @@ -825,32 +724,107 @@ void extract_diagonal(ov::OutputVector& inputs, const auto label_dim_map = compute_label_dim_map(input_node.get_partial_shape().rank(), input_subscript); std::string resultant_subscript; + std::string resultant_subscript_with_duplicates; std::vector repeated_labels; - ov::AxisSet reduced_axes; + std::vector unrepeated_labels; + ov::AxisVector reduced_axes; prepare_diagonal_extraction_data(input_subscript, label_dim_map, resultant_subscript, + resultant_subscript_with_duplicates, repeated_labels, + unrepeated_labels, reduced_axes); if (repeated_labels.size() == 0) { return; } - const auto multi_identity = build_multi_identity(input_node, repeated_labels, label_dim_map, subgraph_nodes); - - // multiply both operands with broadcasting - const auto multi_identity_converted = std::make_shared(multi_identity, input_node); - const auto mul = - std::make_shared(input_node, multi_identity_converted, ov::op::AutoBroadcastType::NUMPY); - subgraph_nodes.insert(subgraph_nodes.end(), {multi_identity_converted, mul}); - - const std::vector reduced_axes_vec{reduced_axes.cbegin(), reduced_axes.cend()}; - const auto axes_const = - ov::op::v0::Constant::create(ov::element::Type_t::i64, ov::Shape{reduced_axes.size()}, reduced_axes_vec); - const auto reduce_sum = std::make_shared(mul->output(0), axes_const, false); - subgraph_nodes.insert(subgraph_nodes.end(), {axes_const, reduce_sum}); + // Transpose input so that repeated labels are grouped by same label and un-repeated labels are moved to the end + transpose_input(inputs, input_subscripts, resultant_subscript, input_ind, subgraph_nodes); + const auto& input_shape = std::make_shared(input_node); + ov::NodeVector begins; + ov::NodeVector ends; + const auto transposed_label_dim_map = + compute_label_dim_map(input_node.get_partial_shape().rank(), resultant_subscript_with_duplicates); + const auto& const_0 = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); + const auto& const_1 = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); + subgraph_nodes.insert(subgraph_nodes.end(), {input_shape, const_0, const_1}); + ov::NodeVector convenient_shape_vector; + ov::NodeVector shape_after_pad_vector; + for (std::string repeated_label : repeated_labels) { + const auto dim_map_repeated_label = transposed_label_dim_map.find(repeated_label); + OPENVINO_ASSERT(dim_map_repeated_label != transposed_label_dim_map.end()); + const auto& repeated_label_dims = dim_map_repeated_label->second; + const auto& repeated_label_indices = + ov::op::v0::Constant::create(ov::element::i64, {repeated_label_dims.size()}, repeated_label_dims); + const auto& repeated_label_indices_len = + ov::op::v0::Constant::create(ov::element::i64, {}, {repeated_label_dims.size()}); + const auto& repeated_dimensions = + std::make_shared(input_shape, repeated_label_indices, const_0); + const auto& repeated_dimension = std::make_shared(repeated_dimensions, const_0, const_0); + const auto& range_max_val = std::make_shared(repeated_dimension, repeated_label_indices_len); + const auto& step_numerator = std::make_shared(range_max_val, const_1); + const auto& step_denominator = 
std::make_shared(repeated_dimension, const_1);
+        const auto& step_denominator_but_not_0 = std::make_shared(step_denominator, const_1);
+        const auto& step_numerator_but_not_0 = std::make_shared(step_numerator, const_1);
+        const auto& step = std::make_shared(step_numerator_but_not_0, step_denominator_but_not_0);
+        const auto& end = std::make_shared(step, const_1);
+        const auto& reduced_size = std::make_shared(repeated_dimensions, const_0, true);
+        convenient_shape_vector.push_back(reduced_size);
+        shape_after_pad_vector.push_back(repeated_dimension);
+        shape_after_pad_vector.push_back(step);
+        begins.push_back(const_0);
+        ends.push_back(end);
+        subgraph_nodes.insert(subgraph_nodes.end(),
+                              {repeated_label_indices,
+                               repeated_label_indices_len,
+                               repeated_dimensions,
+                               repeated_dimension,
+                               range_max_val,
+                               step_numerator,
+                               step_denominator,
+                               step_denominator_but_not_0,
+                               step_numerator_but_not_0,
+                               step,
+                               end,
+                               reduced_size});
+    }
+    for (std::string unrepeated_label : unrepeated_labels) {
+        const auto& dim_map_unrepeated_label = transposed_label_dim_map.find(unrepeated_label);
+        OPENVINO_ASSERT(dim_map_unrepeated_label != transposed_label_dim_map.end());
+        const auto& unrepeated_label_dims = dim_map_unrepeated_label->second;
+        const auto& unrepeated_label_indices =
+            ov::op::v0::Constant::create(ov::element::i64, {unrepeated_label_dims.size()}, unrepeated_label_dims);
+        const auto& unrepeated_dimensions =
+            std::make_shared(input_shape, unrepeated_label_indices, const_0);
+        convenient_shape_vector.push_back(unrepeated_dimensions);
+        shape_after_pad_vector.push_back(unrepeated_dimensions);
+        begins.insert(begins.end(), unrepeated_label_dims.size(), const_0);
+        ends.insert(ends.end(), unrepeated_label_dims.size(), const_0);
+        subgraph_nodes.insert(subgraph_nodes.end(), {unrepeated_label_indices, unrepeated_dimensions});
+    }
+    const auto& convenient_shape = std::make_shared(convenient_shape_vector, 0);
+    const auto& pads_end = std::make_shared(ends, 0);
+    const auto& pads_begin = std::make_shared(begins, 0);
+    const auto& reshaped_input = std::make_shared(input_node, convenient_shape, false);
+    const auto& pad =
+        std::make_shared(reshaped_input, pads_begin, pads_end, ov::op::PadMode::CONSTANT);
+    const auto& reshape_after_pad_target = std::make_shared(shape_after_pad_vector, 0);
+    const auto& reshape_after_pad = std::make_shared(pad, reshape_after_pad_target, false);
+    subgraph_nodes.insert(
+        subgraph_nodes.end(),
+        {convenient_shape, pads_begin, pads_end, reshaped_input, pad, reshape_after_pad_target, reshape_after_pad});
+    std::shared_ptr gather = reshape_after_pad;
+    for (auto axis : reduced_axes) {
+        auto axis_const = ov::op::v0::Constant::create(ov::element::i64, {1}, {axis});
+        gather = std::make_shared(gather, const_0, axis_const);
+        subgraph_nodes.insert(subgraph_nodes.end(), {axis_const, gather});
+    }
+    const auto& reduced_indices = ov::op::v0::Constant::create(ov::element::i64, {reduced_axes.size()}, reduced_axes);
+    const auto& out_node = std::make_shared(gather, reduced_indices);
+    subgraph_nodes.insert(subgraph_nodes.end(), {reduced_indices, out_node});

-    inputs[input_ind] = reduce_sum->output(0);
+    inputs[input_ind] = out_node->output(0);
     input_subscripts[input_ind] = resultant_subscript;
 }

From 3fc4a70837176562d9b075b2aa1c753dac85ea9e Mon Sep 17 00:00:00 2001
From: Mateusz Mikolajczyk
Date: Tue, 18 Feb 2025 18:25:39 +0100
Subject: [PATCH 40/47] Remove unnecessary loop

Signed-off-by: Mateusz Mikolajczyk
---
 .../op_conversions/einsum_decomposition.cpp | 19 ++++++++++++-------
1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp index 2040e320547930..6f79316b4fd2b4 100644 --- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp @@ -793,16 +793,21 @@ void extract_diagonal(ov::OutputVector& inputs, const auto& dim_map_unrepeated_label = transposed_label_dim_map.find(unrepeated_label); OPENVINO_ASSERT(dim_map_unrepeated_label != transposed_label_dim_map.end()); const auto& unrepeated_label_dims = dim_map_unrepeated_label->second; - const auto& unrepeated_label_indices = - ov::op::v0::Constant::create(ov::element::i64, {unrepeated_label_dims.size()}, unrepeated_label_dims); - const auto& unrepeated_dimensions = - std::make_shared(input_shape, unrepeated_label_indices, const_0); - convenient_shape_vector.push_back(unrepeated_dimensions); - shape_after_pad_vector.push_back(unrepeated_dimensions); + unrepeated_dimension_indices_vec.insert(unrepeated_dimension_indices_vec.end(), + unrepeated_label_dims.begin(), + unrepeated_label_dims.end()); begins.insert(begins.end(), unrepeated_label_dims.size(), const_0); ends.insert(ends.end(), unrepeated_label_dims.size(), const_0); - subgraph_nodes.insert(subgraph_nodes.end(), {unrepeated_label_indices, unrepeated_dimensions}); } + const auto& unrepeated_dimensions_indices = ov::op::v0::Constant::create(ov::element::i64, + {unrepeated_dimension_indices_vec.size()}, + unrepeated_dimension_indices_vec); + const auto& unrepeated_dimensions = + std::make_shared(input_shape, unrepeated_dimensions_indices, const_0); + subgraph_nodes.insert(subgraph_nodes.end(), {unrepeated_dimensions_indices, unrepeated_dimensions}); + convenient_shape_vector.push_back(unrepeated_dimensions); + shape_after_pad_vector.push_back(unrepeated_dimensions); + const auto& convenient_shape = std::make_shared(convenient_shape_vector, 0); const auto& pads_end = std::make_shared(ends, 0); const auto& pads_begin = std::make_shared(begins, 0); From 153aaaad274eea59570850dd8d1e98b4fc1f3f92 Mon Sep 17 00:00:00 2001 From: Mateusz Mikolajczyk Date: Wed, 19 Feb 2025 14:07:47 +0100 Subject: [PATCH 41/47] Fix einsum decomposition Signed-off-by: Mateusz Mikolajczyk --- .../src/transformations/op_conversions/einsum_decomposition.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp index 6f79316b4fd2b4..be66856a56abeb 100644 --- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp @@ -789,6 +789,7 @@ void extract_diagonal(ov::OutputVector& inputs, end, reduced_size}); } + std::vector unrepeated_dimension_indices_vec; for (std::string unrepeated_label : unrepeated_labels) { const auto& dim_map_unrepeated_label = transposed_label_dim_map.find(unrepeated_label); OPENVINO_ASSERT(dim_map_unrepeated_label != transposed_label_dim_map.end()); From cd7658bc1a6a2169a5b043254354ee07cf6f0a0d Mon Sep 17 00:00:00 2001 From: Mateusz Mikolajczyk Date: Wed, 19 Feb 2025 17:27:42 +0100 Subject: [PATCH 42/47] Update einsum decomposition test + add inline comments with descriptions 
Signed-off-by: Mateusz Mikolajczyk --- .../op_conversions/einsum_decomposition.cpp | 29 +- .../einsum_decomposition_test.cpp | 362 ++++++++++++------ 2 files changed, 267 insertions(+), 124 deletions(-) diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp index be66856a56abeb..edf091deebcaa3 100644 --- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp @@ -701,7 +701,7 @@ void prepare_diagonal_extraction_data(const std::string& input_subscript, /// /// \brief Extracts the diagonal elements from the input tensor based on the provided subscripts. /// -/// This function modifies the input tensor by extracting its diagonal elements for repeated lables and updating the +/// This function modifies the input tensor by extracting its diagonal elements for repeated labels and updating the /// corresponding subscript. /// /// \param inputs A vector of input tensors. @@ -722,12 +722,15 @@ void extract_diagonal(ov::OutputVector& inputs, const auto& input_node = inputs[input_ind]; const auto& input_subscript = input_subscripts[input_ind]; + // Compute the label to dimension map for the input subscript const auto label_dim_map = compute_label_dim_map(input_node.get_partial_shape().rank(), input_subscript); std::string resultant_subscript; std::string resultant_subscript_with_duplicates; std::vector repeated_labels; std::vector unrepeated_labels; ov::AxisVector reduced_axes; + + // Prepare data for diagonal extraction prepare_diagonal_extraction_data(input_subscript, label_dim_map, resultant_subscript, @@ -736,14 +739,20 @@ void extract_diagonal(ov::OutputVector& inputs, unrepeated_labels, reduced_axes); + // If there are no repeated labels, return early if (repeated_labels.size() == 0) { return; } + // Transpose input so that repeated labels are grouped by same label and un-repeated labels are moved to the end transpose_input(inputs, input_subscripts, resultant_subscript, input_ind, subgraph_nodes); + + // Create a ShapeOf operation to get the shape of the input tensor const auto& input_shape = std::make_shared(input_node); ov::NodeVector begins; ov::NodeVector ends; + + // Compute the label to dimension map for the transposed input subscript const auto transposed_label_dim_map = compute_label_dim_map(input_node.get_partial_shape().rank(), resultant_subscript_with_duplicates); const auto& const_0 = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); @@ -751,6 +760,8 @@ void extract_diagonal(ov::OutputVector& inputs, subgraph_nodes.insert(subgraph_nodes.end(), {input_shape, const_0, const_1}); ov::NodeVector convenient_shape_vector; ov::NodeVector shape_after_pad_vector; + + // Process each repeated label for (std::string repeated_label : repeated_labels) { const auto dim_map_repeated_label = transposed_label_dim_map.find(repeated_label); OPENVINO_ASSERT(dim_map_repeated_label != transposed_label_dim_map.end()); @@ -770,7 +781,9 @@ void extract_diagonal(ov::OutputVector& inputs, const auto& step = std::make_shared(step_numerator_but_not_0, step_denominator_but_not_0); const auto& end = std::make_shared(step, const_1); const auto& reduced_size = std::make_shared(repeated_dimensions, const_0, true); + // Flatten dimensions of repeated label convenient_shape_vector.push_back(reduced_size); + // Compute the new shape after padding, separate 
diagonal elements shape_after_pad_vector.push_back(repeated_dimension); shape_after_pad_vector.push_back(step); begins.push_back(const_0); @@ -789,6 +802,8 @@ void extract_diagonal(ov::OutputVector& inputs, end, reduced_size}); } + + // Process unrepeated labels - do not modify or pad dimensions std::vector unrepeated_dimension_indices_vec; for (std::string unrepeated_label : unrepeated_labels) { const auto& dim_map_unrepeated_label = transposed_label_dim_map.find(unrepeated_label); @@ -800,6 +815,8 @@ void extract_diagonal(ov::OutputVector& inputs, begins.insert(begins.end(), unrepeated_label_dims.size(), const_0); ends.insert(ends.end(), unrepeated_label_dims.size(), const_0); } + + // Gather the dimensions for unrepeated labels in single call const auto& unrepeated_dimensions_indices = ov::op::v0::Constant::create(ov::element::i64, {unrepeated_dimension_indices_vec.size()}, unrepeated_dimension_indices_vec); @@ -809,27 +826,35 @@ void extract_diagonal(ov::OutputVector& inputs, convenient_shape_vector.push_back(unrepeated_dimensions); shape_after_pad_vector.push_back(unrepeated_dimensions); + // Create the new shape for the input tensor that would flatten repeated label dimensions const auto& convenient_shape = std::make_shared(convenient_shape_vector, 0); + const auto& reshaped_input = std::make_shared(input_node, convenient_shape, false); + // Create the pads for the label-flattened input tensor to extract the diagonal elements const auto& pads_end = std::make_shared(ends, 0); const auto& pads_begin = std::make_shared(begins, 0); - const auto& reshaped_input = std::make_shared(input_node, convenient_shape, false); const auto& pad = std::make_shared(reshaped_input, pads_begin, pads_end, ov::op::PadMode::CONSTANT); + // Reshape the tensor after padding to extract the diagonal elements to separate dimensions const auto& reshape_after_pad_target = std::make_shared(shape_after_pad_vector, 0); const auto& reshape_after_pad = std::make_shared(pad, reshape_after_pad_target, false); subgraph_nodes.insert( subgraph_nodes.end(), {convenient_shape, pads_begin, pads_end, reshaped_input, pad, reshape_after_pad_target, reshape_after_pad}); + + // Gather the diagonal elements std::shared_ptr gather = reshape_after_pad; for (auto axis : reduced_axes) { auto axis_const = ov::op::v0::Constant::create(ov::element::i64, {1}, {axis}); gather = std::make_shared(gather, const_0, axis_const); subgraph_nodes.insert(subgraph_nodes.end(), {axis_const, gather}); } + + // Squeeze the gathered tensor to remove the reduced axes const auto& reduced_indices = ov::op::v0::Constant::create(ov::element::i64, {reduced_axes.size()}, reduced_axes); const auto& out_node = std::make_shared(gather, reduced_indices); subgraph_nodes.insert(subgraph_nodes.end(), {reduced_indices, out_node}); + // Update the input tensor and its subscript inputs[input_ind] = out_node->output(0); input_subscripts[input_ind] = resultant_subscript; } diff --git a/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp b/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp index 55fc5a9c3a1fb6..2e8a4105bc26b6 100644 --- a/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp +++ b/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp @@ -34,67 +34,130 @@ std::shared_ptr broadcast_merge_shapes(const std::shared_ptr auto broadcasted_shapes = makeOP({shape_node_lhs, shape_node_rhs}, {{"auto_broadcast", "numpy"}}); return broadcasted_shapes; } - -std::shared_ptr 
create_identity(const std::shared_ptr& data, - const std::vector& repated_label_indices) { - auto shapeof_data = makeOP({data}, {{"output_type", "i64"}}); - auto rankof_data = makeOP({shapeof_data}); - auto const_0 = makeConst(element::i64, ov::Shape({}), {0}); - auto const_1 = makeConst(element::i64, ov::Shape({}), {1}); - auto num_of_repeated_labels = makeConst(element::i64, ov::Shape({}), {repated_label_indices.size()}); - auto repeated_label_indices = makeConst(element::i64, - ov::Shape({ - repated_label_indices.size(), - }), - repated_label_indices); - auto repeated_dimensions = - makeOP({shapeof_data, repeated_label_indices, const_0}, {{"batch_dims", 0}}); - auto repeated_dimensions_size = makeOP({repeated_dimensions, const_0}, {{"keep_dims", true}}); - auto zeros_of_size = makeOP({const_0, repeated_dimensions_size}, {{"mode", "numpy"}}); - auto repeated_dimension = makeOP({repeated_dimensions, const_0, const_0}, {{"batch_dims", 0}}); - auto range_max_val = - makeOP({repeated_dimension, num_of_repeated_labels}, {{"auto_broadcast", "numpy"}}); - auto step_numerator = makeOP({range_max_val, const_1}, {{"auto_broadcast", "numpy"}}); - auto step_numerator_but_not_0 = makeOP({step_numerator, const_1}, {{"auto_broadcast", "numpy"}}); - auto step_denominator = makeOP({repeated_dimension, const_1}, {{"auto_broadcast", "numpy"}}); - auto step_denominator_but_not_0 = - makeOP({step_denominator, const_1}, {{"auto_broadcast", "numpy"}}); - auto step = makeOP({step_numerator_but_not_0, step_denominator_but_not_0}, - {{"auto_broadcast", "numpy"}, {"m_pythondiv", true}}); - auto eye_flattened_indices = makeOP({const_0, range_max_val, step}); - auto repeated_dimension_1d = makeOP({repeated_dimension, const_0}); - auto ones = makeOP({const_1, repeated_dimension_1d}, {{"mode", "numpy"}}); - auto eye_flattened = makeOP({zeros_of_size, eye_flattened_indices, ones, const_0}); - auto ones_of_input_shape_rank = makeOP({const_1, rankof_data}, {{"mode", "numpy"}}); - auto identity_shape = makeOP( - {ones_of_input_shape_rank, repeated_label_indices, repeated_dimensions, const_0}); - auto identity = makeOP({eye_flattened, identity_shape}, {{"special_zero", false}}); - return identity; -} - +/// +/// \brief Extracts the diagonal elements from the input tensor based on the specified repeated and unrepeated indices. +/// +/// This function performs a series of operations to extract the diagonal elements from the input tensor `data`. +/// It first transposes the input tensor based on the repeated and unrepeated indices, then reshapes and pads the tensor +/// to isolate the diagonal elements. Finally, it gathers and squeezes the tensor to obtain the diagonal elements. +/// +/// \param data A shared pointer to the input tensor node. +/// \param indices_of_repeated_labels A vector of vectors containing the indices of repeated labels. +/// \param unrepeated_indices A vector containing the indices of unrepeated labels. Default is an empty vector. +/// \return A shared pointer to the node representing the diagonal elements of the input tensor. std::shared_ptr extract_diagonal(const std::shared_ptr& data, - const std::vector>& indices_of_repeated_labels) { - // Initialize multi_identity by identity for first repeated label. - auto multi_identity = create_identity(data, indices_of_repeated_labels[0]); - // Initialize reduction axes by all except first repated_label_indices for first repeated label. 
-    std::vector reduce_axes(indices_of_repeated_labels[0].begin() + 1, indices_of_repeated_labels[0].end());
-    // Merge remaining identities.
-    for (size_t i = 1; i < indices_of_repeated_labels.size(); i++) {
-        auto identity = create_identity(data, indices_of_repeated_labels[i]);
-        multi_identity = makeOP({multi_identity, identity}, {{"auto_broadcast", "numpy"}});
-        reduce_axes.insert(reduce_axes.end(),
-                           indices_of_repeated_labels[i].begin() + 1,
-                           indices_of_repeated_labels[i].end());
+                                 const std::vector>& indices_of_repeated_labels,
+                                 const std::vector& unrepeated_indices = {}) {
+    std::vector transpose_group_labels_target;
+    std::vector reduced_axes;
+
+    // Prepare the target order for transposing the input tensor
+    for (size_t i = 0; i < indices_of_repeated_labels.size(); i++) {
+        auto repeated_label = indices_of_repeated_labels[i];
+        size_t step = i * 2;
+        reduced_axes.push_back(step + 1);
+        transpose_group_labels_target.insert(transpose_group_labels_target.end(),
+                                             repeated_label.begin(),
+                                             repeated_label.end());
+    }
+    transpose_group_labels_target.insert(transpose_group_labels_target.end(),
+                                         unrepeated_indices.begin(),
+                                         unrepeated_indices.end());
+
+    // Transpose the input tensor to group repeated and unrepeated labels
+    auto const_transpose_group_labels_target =
+        makeConst(element::i64, ov::Shape({transpose_group_labels_target.size()}), transpose_group_labels_target);
+    auto transpose_group_labels = std::make_shared(data, const_transpose_group_labels_target);
+
+    // Get the shape of the transposed tensor
+    auto shapeof_transposed_data = std::make_shared(transpose_group_labels);
+
+    auto const_0 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape({1}), {0});
+    auto const_1 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape({1}), {1});
+
+    ov::NodeVector flattened_shapes;
+    ov::NodeVector unflattened_shapes;
+    ov::NodeVector begins;
+    ov::NodeVector ends;
+
+    std::vector dim_iota(transpose_group_labels_target.size());
+    std::iota(dim_iota.begin(), dim_iota.end(), 0);
+
+    size_t dimension_iter = 0;
+
+    // Process each repeated label group
+    for (auto repeated_label : indices_of_repeated_labels) {
+        auto num_repeats = repeated_label.size();
+        std::vector label_indices = {dim_iota.begin() + dimension_iter,
+                                     dim_iota.begin() + dimension_iter + num_repeats};
+        auto repeated_label_indices_len = ov::op::v0::Constant::create(ov::element::i64, {}, {num_repeats});
+        auto repeated_label_indices =
+            ov::op::v0::Constant::create(ov::element::i64, ov::Shape({num_repeats}), label_indices);
+        auto repeated_dimensions =
+            std::make_shared(shapeof_transposed_data, repeated_label_indices, const_0);
+        auto repeated_dimension = std::make_shared(repeated_dimensions, const_0, const_0);
+        auto range_max_val = std::make_shared(repeated_dimension, repeated_label_indices_len);
+        auto step_numerator = std::make_shared(range_max_val, const_1);
+        auto step_denominator = std::make_shared(repeated_dimension, const_1);
+        auto step_denominator_but_not_0 = std::make_shared(step_denominator, const_1);
+        auto step_numerator_but_not_0 = std::make_shared(step_numerator, const_1);
+        auto step = std::make_shared(step_numerator_but_not_0, step_denominator_but_not_0);
+        auto end = std::make_shared(step, const_1);
+        // Flatten all dimensions of a single repeated label.
+        auto reduced_size = std::make_shared(repeated_dimensions, const_0, true);
+        flattened_shapes.push_back(reduced_size);
+        // Reshape target to restore the per-label dimensions, isolating the diagonal elements from the padded remainder.
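+        // Worked example (illustrative): for a label of extent d = 3 repeated k = 2 times, the
+        // flattened axis has size d^k = 9 and step = (9 - 1) / (3 - 1) = 4; padding that axis with
+        // step - 1 = 3 trailing zeros and reshaping to {3, 4} leaves the diagonal entries
+        // (flat indices 0, 4, 8) in column 0.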
+        unflattened_shapes.push_back(repeated_dimension);
+        unflattened_shapes.push_back(step);
+        begins.push_back(const_0);
+        ends.push_back(end);
+        dimension_iter += num_repeats;
+    }
+
+    // Process unrepeated labels; their dimensions are neither flattened nor padded.
+    std::vector unrepeated_indices_after_transpose = {dim_iota.begin() + dimension_iter, dim_iota.end()};
+    const auto& unrepeated_dimensions_indices =
+        ov::op::v0::Constant::create(ov::element::i64,
+                                     {unrepeated_indices_after_transpose.size()},
+                                     unrepeated_indices_after_transpose);
+    const auto unrepeated_dimensions =
+        std::make_shared(shapeof_transposed_data, unrepeated_dimensions_indices, const_0);
+    begins.insert(begins.end(), unrepeated_indices_after_transpose.size(), const_0);
+    ends.insert(ends.end(), unrepeated_indices_after_transpose.size(), const_0);
+    flattened_shapes.push_back(unrepeated_dimensions);
+    unflattened_shapes.push_back(unrepeated_dimensions);
+
+    // Flatten the tensor to isolate diagonal elements
+    auto flatten_labels_shape_target = std::make_shared(flattened_shapes, 0);
+    auto flatten_labels =
+        std::make_shared(transpose_group_labels, flatten_labels_shape_target, false);
+
+    // Pad the tensor to prepare for gathering diagonal elements
+    auto pad_begin = std::make_shared(begins, 0);
+    auto pad_end = std::make_shared(ends, 0);
+    auto pad = std::make_shared(flatten_labels, pad_begin, pad_end, ov::op::PadMode::CONSTANT);
+
+    // Unflatten the tensor to restore the original shape with diagonal elements isolated
+    auto unflatten_labels_shape_target = std::make_shared(unflattened_shapes, 0);
+    auto unflatten_labels = std::make_shared(pad, unflatten_labels_shape_target, false);
+
+    // Gather the diagonal elements
+    std::shared_ptr gather = unflatten_labels;
+    for (auto axis : reduced_axes) {
+        auto axis_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape({1}), {axis});
+        gather = std::make_shared(gather, const_0, axis_const);
+    }
+
+    // Squeeze the tensor to remove the reduced dimensions
+    auto squeeze_reduced_axes =
+        ov::op::v0::Constant::create(ov::element::i64, ov::Shape({reduced_axes.size()}), reduced_axes);
+    auto diagonal = std::make_shared(gather, squeeze_reduced_axes);
+
     return diagonal;
 }
 }  // namespace
+
 TEST_F(TransformationTestsF, Einsum_2in_matmul) {
     PartialShape data_shape_1{5, 2};
     PartialShape data_shape_2{10, 1, 25};
@@ -324,17 +387,37 @@ TEST_F(TransformationTestsF, Einsum_1in_repeated_labels_ellipsis_static_cf) {
     {
         using namespace ov::gen_pattern;
         auto data_1 = std::make_shared(element::f32, data_shape_1);
-        // If shapes are static, multi-identity can be constant-folded.
-        auto multi_identity = makeConst(
-            element::f32,
-            ov::Shape({1, 3, 1, 1, 3, 1}),
-            {1.000000f, 0.000000f, 0.000000f, 0.000000f, 1.000000f, 0.000000f, 0.000000f, 0.000000f, 1.000000f});
-        auto Multiply_1383 = makeOP({data_1, multi_identity}, {{"auto_broadcast", "numpy"}});
-        auto Constant_1384 = makeConst(element::i64, ov::Shape({3}), {3, 4, 5});
-        auto data_1_diagonal = makeOP({Multiply_1383, Constant_1384}, {{"keep_dims", false}});
+        auto const_0 = makeConst(element::i64, ov::Shape({1}), {0});
+        auto const_1 = makeConst(element::i64, ov::Shape({1}), {1});
+        auto const_3 = makeConst(element::i64, ov::Shape({1}), {3});
+        // Transpose data so repeated labels are grouped and unrepeated labels are moved to the back.
+        // ij...iji -> iiijj...
+        auto transpose_diagonal =
+            makeOP({data_1, makeConst(element::i64, ov::Shape({6}), {0, 3, 5, 1, 4, 2})});
+        // Flatten groups of repeated labels (ij...)
+        auto flatten_repeated_labels =
+            makeOP({transpose_diagonal, makeConst(element::i64, ov::Shape({3}), {1, 9, 2})},
+                   {{"special_zero", false}});
+        // Pad begin and end are constant-folded.
+        auto pad_begin = makeConst(element::i64, ov::Shape({3}), {0, 0, 0});
+        auto pad_end = makeConst(element::i64, ov::Shape({3}), {0, 3, 0});
+        auto pad = makeOP(
+            {flatten_repeated_labels, pad_begin, pad_end, makeConst(element::f32, ov::Shape({}), {0})},
+            {{"pad_mode", "constant"}});
+        // Unflatten padded groups of repeated labels so the layout is i(padded remainder of i)j(padded remainder of j)...
+        auto unflatten_repeated_labels =
+            makeOP({pad, makeConst(element::i64, ov::Shape({5}), {1, 1, 3, 4, 2})},
+                   {{"special_zero", false}});
+        // Reduce padded dimensions to get the diagonal.
+        auto reduce_first_repeat =
+            makeOP({unflatten_repeated_labels, const_0, const_1}, {{"batch_dims", 0}});
+        auto reduce_second_repeat =
+            makeOP({reduce_first_repeat, const_0, const_3}, {{"batch_dims", 0}});
+        auto remove_reduced_dims =
+            makeOP({reduce_second_repeat, makeConst(element::i64, ov::Shape({2}), {1, 3})});
         // Transpose to the original order of output labels.
         auto Constant_1386 = makeConst(element::i64, ov::Shape({3}), {1, 2, 0});
-        auto transpose_out = makeOP({data_1_diagonal, Constant_1386});
+        auto transpose_out = makeOP({remove_reduced_dims, Constant_1386});
         model_ref = std::make_shared(NodeVector{transpose_out}, ParameterVector{data_1});
     }
 }
@@ -384,52 +467,82 @@ TEST_F(TransformationTestsF, Einsum_3in_broadcast_duplicated_ellipsis_repeated_s
         auto node_0 = std::make_shared(element::f32, data_shape_3);
         auto node_2 = std::make_shared(element::f32, data_shape_2);
         auto node_4 = std::make_shared(element::f32, data_shape_1);
-        // ConstantFold folded multi-identity for input 2 to single constant
-        auto Multiply_1990 = makeConst(element::f32, ov::Shape({1, 1, 1, 1, 1, 1}), {1.000000f});
-        // Extract diagonals
-        auto Multiply_1991 = makeOP({node_2, Multiply_1990}, {{"auto_broadcast", "numpy"}});
-        auto Constant_1992 = makeConst(element::i64, ov::Shape({3}), {2, 3, 5});
-        auto ReduceSum_1993 = makeOP({Multiply_1991, Constant_1992}, {{"keep_dims", false}});
-        // Broadcast for ellipsis and labels constant folded to single constant and broadcast
-        auto Concat_2034 = makeConst(element::i64, ov::Shape({3}), {4, 3, 3});
-        // Broadcast ellipsis and labels
-        auto Broadcast_2035 = makeOP({ReduceSum_1993, Concat_2034}, {{"mode", "bidirectional"}});
-        auto Concat_2051 = makeConst(element::i64, ov::Shape({4}), {4, 3, 3, 1});
-        auto Reshape_2052 = makeOP({Broadcast_2035, Concat_2051}, {{"special_zero", false}});
-        auto Convert_1700 = makeConst(element::f32, ov::Shape({1, 1, 1, 1, 1, 1}), {1.000000f});
-        auto Multiply_1701 = makeOP({node_4, Convert_1700}, {{"auto_broadcast", "numpy"}});
-        auto Constant_1702 = makeConst(element::i64, ov::Shape({1}), {5});
-        auto ReduceSum_1703 = makeOP({Multiply_1701, Constant_1702}, {{"keep_dims", false}});
-        auto Constant_1799 = makeConst(element::i64, ov::Shape({1}), {1});
-        auto ReduceSum_1800 = makeOP({ReduceSum_1703, Constant_1799}, {{"keep_dims", false}});
-        auto Constant_1803 = makeConst(element::i64, ov::Shape({2}), {4, 5});
-        auto Unsqueeze_1804 = makeOP({ReduceSum_1800, Constant_1803});
-        auto Constant_1605 = makeConst(element::i64, ov::Shape({1}), {0});
-        auto Unsqueeze_1606 = makeOP({node_0, Constant_1605});
-        auto Constant_1607 = makeConst(element::i64, ov::Shape({2}), {0, 1});
-        auto Unsqueeze_1608 = makeOP({Unsqueeze_1606, Constant_1607});
-        auto Convert_1795 = makeConst(
-            element::f32,
-            ov::Shape({1, 1, 1, 1, 1, 3, 3}),
-            {1.000000f, 0.000000f, 0.000000f, 0.000000f, 1.000000f, 0.000000f, 0.000000f, 0.000000f, 1.000000f});
-        auto Multiply_1796 = makeOP({Unsqueeze_1608, Convert_1795}, {{"auto_broadcast", "numpy"}});
-        auto Constant_1797 = makeConst(element::i64, ov::Shape({1}), {6});
-        auto ReduceSum_1798 = makeOP({Multiply_1796, Constant_1797}, {{"keep_dims", false}});
-        auto Constant_1801 = makeConst(element::i64, ov::Shape({6}), {4, 0, 1, 2, 3, 5});
-        auto Transpose_1802 = makeOP({ReduceSum_1798, Constant_1801});
-        auto Multiply_1805 = makeOP({Unsqueeze_1804, Transpose_1802}, {{"auto_broadcast", "numpy"}});
-        auto Constant_1994 = makeConst(element::i64, ov::Shape({6}), {0, 5, 1, 2, 3, 4});
-        auto Transpose_1995 = makeOP({Multiply_1805, Constant_1994});
-        auto Concat_2043 = makeConst(element::i64, ov::Shape({6}), {4, 3, 2, 1, 1, 3});
-        auto Broadcast_2044 = makeOP({Transpose_1995, Concat_2043}, {{"mode", "bidirectional"}});
-        auto Concat_2076 = makeConst(element::i64, ov::Shape({4}), {4, 3, 2, 3});
-        auto Reshape_2077 = makeOP({Broadcast_2044, Concat_2076}, {{"special_zero", false}});
-        auto MatMul_2116 =
-            makeOP({Reshape_2052, Reshape_2077}, {{"transpose_a", true}, {"transpose_b", true}});
-        auto Concat_2117 = makeConst(element::i64, ov::Shape({5}), {4, 3, 2, 1, 1});
-        auto Reshape_2118 = makeOP({MatMul_2116, Concat_2117}, {{"special_zero", false}});
-        auto Constant_2119 = makeConst(element::i64, ov::Shape({5}), {1, 2, 3, 4, 0});
-        auto node_6 = makeOP({Reshape_2118, Constant_2119});
+        auto Constant_8230 = makeConst(element::i64, ov::Shape({6}), {1, 2, 3, 4, 5, 0});
+        auto Transpose_8231 = makeOP({node_2, Constant_8230});
+        auto Concat_8261 = makeConst(element::i64, ov::Shape({3}), {1, 1, 4});
+        auto Reshape_8264 = makeOP({Transpose_8231, Concat_8261}, {{"special_zero", false}});
+        auto Concat_8263 = makeConst(element::i64, ov::Shape({3}), {0, 0, 0});
+        auto Concat_8262 = makeConst(element::i64, ov::Shape({3}), {0, 0, 0});
+        auto Pad_8304 =
+            makeOP({Reshape_8264, Concat_8263, Concat_8262, 0.000000f}, {{"pad_mode", "constant"}});
+        auto Concat_8381 = makeConst(element::i64, ov::Shape({5}), {1, 1, 1, 1, 4});
+        auto Reshape_8382 = makeOP({Pad_8304, Concat_8381}, {{"special_zero", false}});
+        auto Constant_8233 = makeConst(element::i64, ov::Shape({1}), {0});
+        auto Constant_8459 = makeConst(element::i64, ov::Shape({1}), {1});
+        auto Gather_8460 = makeOP({Reshape_8382, Constant_8233, Constant_8459}, {{"batch_dims", 0}});
+        auto Constant_8461 = makeConst(element::i64, ov::Shape({1}), {3});
+        auto Gather_8462 = makeOP({Gather_8460, Constant_8233, Constant_8461}, {{"batch_dims", 0}});
+        auto Constant_8463 = makeConst(element::i64, ov::Shape({2}), {1, 3});
+        auto Squeeze_8464 = makeOP({Gather_8462, Constant_8463});
+        auto Constant_8465 = makeConst(element::i64, ov::Shape({3}), {0, 2, 1});
+        auto Transpose_8466 = makeOP({Squeeze_8464, Constant_8465});
+        auto Constant_8494 = makeConst(element::i64, ov::Shape({3}), {3, 4, 3});
+        auto Broadcast_8495 = makeOP({Transpose_8466, Constant_8494}, {{"mode", "bidirectional"}});
+        auto Constant_8528 = makeConst(element::i64, ov::Shape({4}), {3, 4, 1, 3});
+        auto Reshape_8529 = makeOP({Broadcast_8495, Constant_8528}, {{"special_zero", false}});
+        auto Constant_7971 = makeConst(element::i64, ov::Shape({6}), {0, 5, 1, 2, 3, 4});
+        auto Transpose_7972 = makeOP({node_4, Constant_7971});
+        auto Concat_7990 = makeConst(element::i64, ov::Shape({5}), {1, 2, 2, 1, 1});
+        auto Reshape_7993 = makeOP({Transpose_7972, Concat_7990}, {{"special_zero", false}});
+        auto Concat_7992 = makeConst(element::i64, ov::Shape({5}), {0, 0, 0, 0, 0});
+        auto Concat_7991 = makeConst(element::i64, ov::Shape({5}), {0, 0, 0, 0, 0});
+        auto Pad_8014 =
+            makeOP({Reshape_7993, Concat_7992, Concat_7991, 0.000000f}, {{"pad_mode", "constant"}});
+        auto Concat_8053 = makeConst(element::i64, ov::Shape({6}), {1, 1, 2, 2, 1, 1});
+        auto Reshape_8054 = makeOP({Pad_8014, Concat_8053}, {{"special_zero", false}});
+        auto Constant_7974 = makeConst(element::i64, ov::Shape({1}), {0});
+        auto Constant_8093 = makeConst(element::i64, ov::Shape({1}), {1});
+        auto Gather_8094 = makeOP({Reshape_8054, Constant_7974, Constant_8093}, {{"batch_dims", 0}});
+        auto Constant_8095 = makeConst(element::i64, ov::Shape({1}), {1});
+        auto Squeeze_8096 = makeOP({Gather_8094, Constant_8095});
+        auto Constant_8223 = makeConst(element::i64, ov::Shape({1}), {1});
+        auto ReduceSum_8224 = makeOP({Squeeze_8096, Constant_8223}, {{"keep_dims", false}});
+        auto Constant_8227 = makeConst(element::i64, ov::Shape({2}), {4, 5});
+        auto Unsqueeze_8228 = makeOP({ReduceSum_8224, Constant_8227});
+        auto Constant_7967 = makeConst(element::i64, ov::Shape({1}), {0});
+        auto Unsqueeze_7968 = makeOP({node_0, Constant_7967});
+        auto Constant_7969 = makeConst(element::i64, ov::Shape({2}), {0, 1});
+        auto Unsqueeze_7970 = makeOP({Unsqueeze_7968, Constant_7969});
+        auto Constant_8097 = makeConst(element::i64, ov::Shape({7}), {5, 6, 0, 1, 2, 3, 4});
+        auto Transpose_8098 = makeOP({Unsqueeze_7970, Constant_8097});
+        auto Concat_8116 = makeConst(element::i64, ov::Shape({6}), {9, 1, 1, 1, 3, 1});
+        auto Reshape_8119 = makeOP({Transpose_8098, Concat_8116}, {{"special_zero", false}});
+        auto Concat_8118 = makeConst(element::i64, ov::Shape({6}), {0, 0, 0, 0, 0, 0});
+        auto Concat_8117 = makeConst(element::i64, ov::Shape({6}), {3, 0, 0, 0, 0, 0});
+        auto Pad_8140 =
+            makeOP({Reshape_8119, Concat_8118, Concat_8117, 0.000000f}, {{"pad_mode", "constant"}});
+        auto Concat_8179 = makeConst(element::i64, ov::Shape({7}), {3, 4, 1, 1, 1, 3, 1});
+        auto Reshape_8180 = makeOP({Pad_8140, Concat_8179}, {{"special_zero", false}});
+        auto Constant_8100 = makeConst(element::i64, ov::Shape({1}), {0});
+        auto Constant_8219 = makeConst(element::i64, ov::Shape({1}), {1});
+        auto Gather_8220 = makeOP({Reshape_8180, Constant_8100, Constant_8219}, {{"batch_dims", 0}});
+        auto Constant_8221 = makeConst(element::i64, ov::Shape({1}), {1});
+        auto Squeeze_8222 = makeOP({Gather_8220, Constant_8221});
+        auto Constant_8225 = makeConst(element::i64, ov::Shape({6}), {5, 1, 2, 3, 0, 4});
+        auto Transpose_8226 = makeOP({Squeeze_8222, Constant_8225});
+        auto Multiply_8229 = makeOP({Unsqueeze_8228, Transpose_8226}, {{"auto_broadcast", "numpy"}});
+        auto Constant_8467 = makeConst(element::i64, ov::Shape({6}), {4, 0, 1, 2, 3, 5});
+        auto Transpose_8468 = makeOP({Multiply_8229, Constant_8467});
+        auto Constant_8502 = makeConst(element::i64, ov::Shape({6}), {3, 4, 2, 1, 1, 3});
+        auto Broadcast_8503 = makeOP({Transpose_8468, Constant_8502}, {{"mode", "bidirectional"}});
+        auto Constant_8573 = makeConst(element::i64, ov::Shape({4}), {3, 4, 2, 3});
+        auto Reshape_8574 = makeOP({Broadcast_8503, Constant_8573}, {{"special_zero", false}});
+        auto MatMul_8575 =
+            makeOP({Reshape_8529, Reshape_8574}, {{"transpose_a", false}, {"transpose_b", true}});
+        auto Constant_8577 = makeConst(element::i64, ov::Shape({5}), {3, 4, 2, 1, 1});
+        auto Reshape_8578 = makeOP({MatMul_8575, Constant_8577}, {{"special_zero", false}});
+        auto Constant_8579 = makeConst(element::i64, ov::Shape({5}), {0, 2, 3, 4, 1});
+        auto node_6 = makeOP({Reshape_8578, Constant_8579});
         model_ref = std::make_shared(NodeVector{node_6}, ParameterVector{node_4, node_2, node_0});
     }
 }
@@ -462,36 +575,41 @@ TEST_F(TransformationTestsF, Einsum_3in_broadcast_duplicated_ellipsis_repeated_d
         auto data3_insert_missing_ellipsis = makeOP({data_3, ellipsis_idx});
         auto align_ellipsis_idx = makeConst(element::i64, ov::Shape({2}), {0, 1});
         auto data_3_processed = makeOP({data3_insert_missing_ellipsis, align_ellipsis_idx});
-        auto data_3_diagonal = extract_diagonal(data_3_processed, {{5, 6}});
+        auto data_3_diagonal = extract_diagonal(data_3_processed, {{5, 6}}, {0, 1, 2, 3, 4});
         // No reduced labels - use the simplified subgraph that uses Multiply instead of MatMul.
-        auto convenient_layout = makeConst(element::i64, ov::Shape({6}), {0, 1, 2, 4, 3, 5});
-        // ...dbc -> ...bdc
+        // c...db -> ...bcd
+        auto convenient_layout = makeConst(element::i64, ov::Shape({6}), {1, 2, 3, 5, 0, 4});
         auto rhs_convenient_layout = makeOP({data_3_diagonal, convenient_layout});
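[Editor's note — illustration, not part of the patch. When the pairwise contraction reduces no labels, einsum degenerates to a broadcasted elementwise product, which is why this path builds a Multiply rather than the MatMul used elsewhere. A minimal scalar sketch in plain C++ for the assumed equation "ab,b->ab" (all names invented):

#include <array>
#include <cstddef>

int main() {
    std::array<float, 6> lhs = {1, 2, 3, 4, 5, 6};  // shape [2, 3], row-major
    std::array<float, 3> rhs = {10, 20, 30};        // shape [3]
    std::array<float, 6> out{};                     // "ab,b->ab": no label is summed away
    for (std::size_t a = 0; a < 2; ++a)
        for (std::size_t b = 0; b < 3; ++b)
            out[a * 3 + b] = lhs[a * 3 + b] * rhs[b];  // pure broadcasted multiply
    return 0;
}
]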
         // Optionally unsqueeze both operands for elementwise multiplication with broadcasting.
         // For the LHS operand, unsqueeze at the RHS separate dimension indices (placed at the end of RHS by transpose).
         auto lhs_unsqueeze_dims = makeConst(element::i64, ov::Shape({2}), {4, 5});
         auto lhs_unsqueeze = makeOP({data_1_processed, lhs_unsqueeze_dims});
         // Out subscript = LHS_subscript + RHS_separate_part_subscript
-        // ...bdc = ...b + dc
+        // ...bcd = ...b + cd
         auto data_1_3 = makeOP({lhs_unsqueeze, rhs_convenient_layout}, {{"auto_broadcast", "numpy"}});
         // Second pair of einsum inputs - data_2 and the result of the first pair:
-        // bcccdd,...bdc->c...b
+        // bcccdd,...bcd->c...b
         // data_2 - handle repeated labels
         auto data_2_diagonal = extract_diagonal(data_2,
                                                 {
                                                     {1, 2, 3},  // indices of repeated label c
                                                     {4, 5},     // indices of repeated label d
+                                                },
+                                                {
+                                                    {0},  // indices of unrepeated label b
                                                 });
+        // Transpose data_2 so that common, separate, and reduced labels are grouped for both operands.
+        auto data_2_processed =
+            makeOP({data_2_diagonal, makeConst(element::i64, ov::Shape({3}), {0, 2, 1})});
         // data_1_3 - transpose to correctly group common, separate and reduced labels
-        // ...bdc->bc...d
-        auto transpose_data_1_3_target = makeConst(element::i64, ov::Shape({6}), {3, 5, 0, 1, 2, 4});
+        auto transpose_data_1_3_target = makeConst(element::i64, ov::Shape({6}), {4, 3, 0, 1, 2, 5});
         auto data_1_3_processed = makeOP({data_1_3, transpose_data_1_3_target});
         // Extract and broadcast common subshapes (bc)
         auto shapeof_data_1_3 = makeOP({data_1_3_processed}, {{"output_type", "i64"}});
         auto common_data_1_3 = extract_subshape_from_shape(shapeof_data_1_3, 0, 2);
-        auto shapeof_data_2 = makeOP({data_2_diagonal}, {{"output_type", "i64"}});
+        auto shapeof_data_2 = makeOP({data_2_processed}, {{"output_type", "i64"}});
         auto common_data_2 = extract_subshape_from_shape(shapeof_data_2, 0, 2);
         auto common_broadcast_merge_shapes = broadcast_merge_shapes(common_data_2, common_data_1_3);
@@ -507,7 +625,7 @@ TEST_F(TransformationTestsF, Einsum_3in_broadcast_duplicated_ellipsis_repeated_d
         auto broadcast_data_2_target =
             makeOP({common_broadcast_merge_shapes, reduced_broadcast_merge_shapes}, {{"axis", 0}});
         auto broadcast_data_2 =
-            makeOP({data_2_diagonal, broadcast_data_2_target}, {{"mode", "bidirectional"}});
+            makeOP({data_2_processed, broadcast_data_2_target}, {{"mode", "bidirectional"}});
         auto broadcast_data_1_3_target =
             makeOP({common_broadcast_merge_shapes, separate_data_1_3, reduced_broadcast_merge_shapes},
                    {{"axis", 0}});
@@ -520,7 +638,7 @@ TEST_F(TransformationTestsF, Einsum_3in_broadcast_duplicated_ellipsis_repeated_d
         // Reshape data_2
         auto separate_data_2_placeholder = makeConst(element::i64, ov::Shape({1}), {1});
         auto reshape_data_2_target =
-            makeOP({common_broadcast_merge_shapes, reduced_prod, separate_data_2_placeholder},
+            makeOP({common_broadcast_merge_shapes, separate_data_2_placeholder, reduced_prod},
                    {{"axis", 0}});
         auto reshape_data_2 =
             makeOP({broadcast_data_2, reshape_data_2_target}, {{"special_zero", false}});
@@ -534,11 +652,11 @@ TEST_F(TransformationTestsF, Einsum_3in_broadcast_duplicated_ellipsis_repeated_d
         auto reshape_data_1_3 =
             makeOP({broadcast_data_1_3, reshape_data_1_3_target}, {{"special_zero", false}});
         auto matmul =
-            makeOP({reshape_data_2, reshape_data_1_3}, {{"transpose_a", true}, {"transpose_b", true}});
+            makeOP({reshape_data_2, reshape_data_1_3}, {{"transpose_a", false}, {"transpose_b", true}});
         auto reshape_out_subshape =
             makeOP({common_broadcast_merge_shapes, separate_data_1_3}, {{"axis", 0}});
         auto reshape_out = makeOP({matmul, reshape_out_subshape}, {{"special_zero", false}});
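[Editor's note — inference from the diff, not stated in the patch itself. The operand-order swap in reshape_data_2_target and the transpose_a flip above appear to be two halves of one change: with transpose_a=false and transpose_b=true, MatMul contracts the last axis of each operand,

    A: [common..., m, k]  x  B: [common..., n, k]  ->  [common..., m, n]

so data_2 must now be laid out [common..., separate(=1), reduced] rather than [common..., reduced, separate(=1)]. Here m is data_2's size-1 placeholder separate dimension, n collects the separate dimensions of data_1_3, and k is the product of the reduced dimensions.]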
-        auto Constant_1965 = makeConst(element::i64, ov::Shape({5}), {1, 2, 3, 4, 0});
+        auto Constant_1965 = makeConst(element::i64, ov::Shape({5}), {0, 2, 3, 4, 1});
         auto transpose_out = makeOP({reshape_out, Constant_1965});
         model_ref = std::make_shared(NodeVector{transpose_out}, ParameterVector{data_1, data_2, data_3});
     }
 }

From 98c356fe65257d617c94d373c36eb2ec9513ca06 Mon Sep 17 00:00:00 2001
From: Mateusz Mikolajczyk
Date: Wed, 19 Feb 2025 17:54:00 +0100
Subject: [PATCH 43/47] Fix CI issue

Signed-off-by: Mateusz Mikolajczyk

---
 .../tests/op_conversions/einsum_decomposition_test.cpp          | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp b/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp
index 2e8a4105bc26b6..27364bd1c3bdc7 100644
--- a/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp
+++ b/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp
@@ -598,7 +598,7 @@ TEST_F(TransformationTestsF, Einsum_3in_broadcast_duplicated_ellipsis_repeated_d
                                                     {4, 5},     // indices of repeated label d
                                                 },
                                                 {
-                                                    {0},  // indices of unrepeated label b
+                                                    0,  // indices of unrepeated label b
                                                 });
         // Transpose data_2 so that common, separate, and reduced labels are grouped for both operands.
         auto data_2_processed =

From d943aee2a93602af94fa0eecf86f27dffee6daf2 Mon Sep 17 00:00:00 2001
From: Mateusz Mikolajczyk
Date: Thu, 20 Feb 2025 15:54:59 +0100
Subject: [PATCH 44/47] Update
 src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp

Co-authored-by: Katarzyna Mitrus

---
 .../transformations/op_conversions/einsum_decomposition.cpp    | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp
index edf091deebcaa3..928303144ff482 100644
--- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp
+++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp
@@ -207,8 +207,7 @@ LabelDimMap compute_label_dim_map(const ov::Rank& input_rank, const std::string&
             resulted_map[label].push_back(current_dim);
             ++current_dim;
         } else {
-            std::vector label_dims;
-            label_dims.push_back(current_dim);
+            std::vector label_dims{current_dim};
             resulted_map[label] = label_dims;
             ++current_dim;
         }

From 4b7627aaf676c41105a413ab67a624a47ec54408 Mon Sep 17 00:00:00 2001
From: Mateusz Mikolajczyk
Date: Thu, 20 Feb 2025 16:48:15 +0100
Subject: [PATCH 45/47] Compare accuracy in decomposition tests

Signed-off-by: Mateusz Mikolajczyk

---
 .../tests/op_conversions/einsum_decomposition_test.cpp          | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp b/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp
index 27364bd1c3bdc7..bcd5cabd37670f 100644
--- a/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp
+++ b/src/common/transformations/tests/op_conversions/einsum_decomposition_test.cpp
@@ -161,6 +161,7 @@ std::shared_ptr extract_diagonal(const std::shared_ptr& data
 TEST_F(TransformationTestsF, Einsum_2in_matmul) {
     PartialShape data_shape_1{5, 2};
     PartialShape data_shape_2{10, 1, 25};
+    comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
     {
         auto data_1 = std::make_shared(element::f32, data_shape_1);
         auto data_2 = std::make_shared(element::f32, data_shape_2);
@@ -211,6 +212,7 @@ TEST_F(TransformationTestsF, Einsum_2in_matmul) {
 TEST_F(TransformationTestsF, Einsum_2in_matmul_dynamic) {
     PartialShape data_shape_1 = PartialShape::dynamic(2);
     PartialShape data_shape_2 = PartialShape::dynamic(3);
+    comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
     {
         auto data_1 = std::make_shared(element::f32, data_shape_1);
         auto data_2 = std::make_shared(element::f32, data_shape_2);
@@ -293,6 +295,7 @@ TEST_F(TransformationTestsF, Einsum_2in_matmul_dynamic) {
 TEST_F(TransformationTestsF, Einsum_2in_matmul_ellipsis_dynamic) {
     PartialShape data_shape_1 = PartialShape::dynamic(2);
     PartialShape data_shape_2 = PartialShape::dynamic(5);
+    comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
     {
         auto data_1 = std::make_shared(element::f32, data_shape_1);
         auto data_2 = std::make_shared(element::f32, data_shape_2);
@@ -377,6 +380,7 @@ TEST_F(TransformationTestsF, Einsum_2in_matmul_ellipsis_dynamic) {
 TEST_F(TransformationTestsF, Einsum_1in_repeated_labels_ellipsis_static_cf) {
     Shape data_shape_1 = {1, 3, 2, 1, 3, 1};
+    comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
     {
         auto data_1 = std::make_shared(element::f32, data_shape_1);
         auto einsum = std::make_shared(OutputVector{data_1}, "ij...iji->j...i");
@@ -424,6 +428,7 @@ TEST_F(TransformationTestsF, Einsum_1in_repeated_labels_ellipsis_static_cf) {
 TEST_F(TransformationTestsF, Einsum_1in_repeated_labels_empty_ellipsis_dynamic) {
     PartialShape data_shape_1 = PartialShape::dynamic(5);
+    comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
     {
         auto data_1 = std::make_shared(element::f32, data_shape_1);
         auto einsum = std::make_shared(OutputVector{data_1}, "ij...iji->j...i");
@@ -452,6 +457,7 @@ TEST_F(TransformationTestsF, Einsum_3in_broadcast_duplicated_ellipsis_repeated_s
     PartialShape data_shape_1 = {1, 2, 2, 1, 1, 1};
     PartialShape data_shape_2 = {4, 1, 1, 1, 1, 1};
     PartialShape data_shape_3 = {3, 1, 3, 3};
+    comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
     {
         auto data_1 = std::make_shared(element::f32, data_shape_1);
         auto data_2 = std::make_shared(element::f32, data_shape_2);
@@ -551,6 +557,7 @@ TEST_F(TransformationTestsF, Einsum_3in_broadcast_duplicated_ellipsis_repeated_d
     PartialShape data_shape_1 = PartialShape::dynamic(5);
     PartialShape data_shape_2 = PartialShape::dynamic(6);
     PartialShape data_shape_3 = PartialShape::dynamic(4);
+    comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
     {
         auto data_1 = std::make_shared(element::f32, data_shape_1);
         auto data_2 = std::make_shared(element::f32, data_shape_2);

From eff7dfeb8841897c99ed54ffb78a321d39d58c1f Mon Sep 17 00:00:00 2001
From: Mateusz Mikolajczyk
Date: Thu, 20 Feb 2025 16:57:12 +0100
Subject: [PATCH 46/47] Add const_0 to subgraph nodes

Signed-off-by: Mateusz Mikolajczyk

---
 .../transformations/op_conversions/einsum_decomposition.cpp  | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp
index edf091deebcaa3..8c61660cc5a93f 100644
--- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp
+++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp
@@ -480,10 +480,10 @@ ov::Output reshape_input_for_matmul(const ov::Output& input_
         subgraph_nodes.insert(subgraph_nodes.end(), {reduce_axis_const, separate_shape_prod});
     }
     ov::OutputVector reduced_sub_shape_prod;
-    auto const_0 = ov::op::v0::Constant::create(ov::element::i32, {1}, {0});
     if (reduced_sub_shape.size() > 0) {
+        auto const_0 = ov::op::v0::Constant::create(ov::element::i32, {1}, {0});
         auto product = std::make_shared(reduced_sub_shape[0], const_0, true);
-        subgraph_nodes.insert(subgraph_nodes.end(), {reduce_axis_const, product});
+        subgraph_nodes.insert(subgraph_nodes.end(), {reduce_axis_const, const_0, product});
         reduced_sub_shape_prod.push_back(product->output(0));
     }

From 4781c211a1e77c69cbd289d523b01c9dbb838eb8 Mon Sep 17 00:00:00 2001
From: Mateusz Mikolajczyk
Date: Thu, 20 Feb 2025 17:18:51 +0100
Subject: [PATCH 47/47] Apply requested change to compute_ranges

Signed-off-by: Mateusz Mikolajczyk

---
 .../op_conversions/einsum_decomposition.cpp   | 30 ++++++++++---------
 1 file changed, 16 insertions(+), 14 deletions(-)

diff --git a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp
index 8c61660cc5a93f..f3acdc9ef34f7e 100644
--- a/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp
+++ b/src/common/transformations/src/transformations/op_conversions/einsum_decomposition.cpp
@@ -250,24 +250,26 @@ void compute_ranges(const ov::Rank& input_rank,
                     size_t& reduced_end,
                     bool is_separated_first) {
     auto label_to_dim_map = compute_label_dim_map(input_rank, input_subscript);
-    static const std::string ellipsis = "...";
     size_t common_rank = common_labels.size();
-    if (std::find(common_labels.begin(), common_labels.end(), ellipsis) != common_labels.end()) {
-        OPENVINO_ASSERT(label_to_dim_map.find(ellipsis) != label_to_dim_map.end());
-        common_rank += label_to_dim_map[ellipsis].size() - 1;
-    }
     size_t sep_rank = sep_labels.size();
-    if (std::find(sep_labels.begin(), sep_labels.end(), ellipsis) != sep_labels.end()) {
-        OPENVINO_ASSERT(label_to_dim_map.find(ellipsis) != label_to_dim_map.end());
-        sep_rank += label_to_dim_map[ellipsis].size() - 1;
-    }
     size_t reduced_rank = reduced_labels.size();
-    if (std::find(reduced_labels.begin(), reduced_labels.end(), ellipsis) != reduced_labels.end()) {
-        OPENVINO_ASSERT(label_to_dim_map.find(ellipsis) != label_to_dim_map.end());
-        reduced_rank += label_to_dim_map[ellipsis].size() - 1;
-    }
+
+    static const std::string ellipsis = "...";
+    // Adjust each rank to include the ellipsis dimensions.
+    // The initial rank is the number of labels in the input subscript, so a present ellipsis is initially
+    // counted as a single label. Add the actual ellipsis span and subtract 1 to account for that
+    // one-label placeholder.
+    if (label_to_dim_map.find(ellipsis) != label_to_dim_map.end()) {
+        if (std::find(common_labels.begin(), common_labels.end(), ellipsis) != common_labels.end()) {
+            common_rank += label_to_dim_map[ellipsis].size() - 1;
+        }
+        if (std::find(sep_labels.begin(), sep_labels.end(), ellipsis) != sep_labels.end()) {
+            sep_rank += label_to_dim_map[ellipsis].size() - 1;
+        }
+        if (std::find(reduced_labels.begin(), reduced_labels.end(), ellipsis) != reduced_labels.end()) {
+            reduced_rank += label_to_dim_map[ellipsis].size() - 1;
+        }
+    }

     common_begin = 0;
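[Editor's note — worked example, not part of the patch; subscript chosen for illustration. The adjustment above can be sanity-checked by hand: for subscript "a...b" on a rank-5 input, the extracted labels are {"a", "...", "b"} and the ellipsis spans 5 - 2 = 3 dimensions, so label_to_dim_map["..."].size() == 3. If common_labels = {"a", "..."}, the initial common_rank = 2 counts the ellipsis as one placeholder label; adding 3 - 1 = 2 corrects it to the true common rank of 4:

// Hypothetical walk-through of the rank adjustment (values assumed for illustration):
// subscript "a...b", input rank 5   ->  labels {"a", "...", "b"}
// label_to_dim_map["..."] = {1, 2, 3}  // ellipsis spans 5 - 2 = 3 dimensions
// common_labels = {"a", "..."}         // initial common_rank = 2 (ellipsis counted once)
// common_rank += 3 - 1;                // replace the one-label placeholder with its true span
// => common_rank == 4
]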