Extend Einsum Core and common transformation to support broadcasting, repeated labels and ellipsis #28151

Status: Open. Wants to merge 55 commits into base: master.

Commits (55)
637d510 Einsum core improvements (mmikolajcz, Dec 18, 2024)
50c98c1 Einsum decomposition broadcasting + ellipsis support (mmikolajcz, Dec 19, 2024)
d3eac20 Move broadcasting out of reshape conditional (mmikolajcz, Dec 20, 2024)
46098cc Merge branch 'master' of https://github.com/openvinotoolkit/openvino … (mmikolajcz, Jan 2, 2025)
0ec7974 Initial support for repeated labels (mmikolajcz, Jan 7, 2025)
be601ca Merge branch 'master' of https://github.com/openvinotoolkit/openvino … (mmikolajcz, Jan 8, 2025)
fa041ca Remove xfail for onnx einsum test (mmikolajcz, Jan 8, 2025)
6796536 Remove Einsum xfail for torch HF tests (mmikolajcz, Jan 8, 2025)
be8400c Update transpose reshape elimination for MatMul to handle broadcast f… (mmikolajcz, Jan 16, 2025)
d8147d5 Merge branch 'master' of https://github.com/openvinotoolkit/openvino … (mmikolajcz, Jan 17, 2025)
81b5d39 Initial Einsum update to handle ellipsis label without dimensions (mmikolajcz, Jan 17, 2025)
28b579a Update reduce_input in einsum common decomposition (mmikolajcz, Jan 20, 2025)
33acf2e Fix broadcasting of reduced part for reshape (mmikolajcz, Jan 20, 2025)
29b1072 Extend Einsum reference test cases (mmikolajcz, Jan 21, 2025)
9e749f9 Fix divide by 0 and handling 2+ repeated label types for einsum decom… (mmikolajcz, Jan 23, 2025)
50b6d3e Move fix_inputs_with_0d_ellipsis to separate function (mmikolajcz, Jan 23, 2025)
f666700 Modify reshape_input_for_matmul reduced prod to match ne for separate (mmikolajcz, Jan 23, 2025)
6347ed2 Refactor empty ellipsis handling in Einsum decomposition to improve c… (mmikolajcz, Jan 23, 2025)
6f1732f Merge branch 'master' of https://github.com/openvinotoolkit/openvino … (mmikolajcz, Jan 23, 2025)
d380155 Refactor handling of 0-dimensional ellipsis in Einsum operations for … (mmikolajcz, Jan 24, 2025)
1918991 Refactor broadcast_merge_shapes to eliminate loop (mmikolajcz, Jan 24, 2025)
2eee35c Fix shape_infer for reduced out ellipsis with dynamic rank inputs (mmikolajcz, Jan 30, 2025)
6d53d09 Implement unsqueeze_ellipses_to_same_rank function for consistent ell… (mmikolajcz, Feb 6, 2025)
a73ca6f Implement requested changes to increase clarity (mmikolajcz, Feb 6, 2025)
9b1a08e Fix assert in unsqueeze_ellipses_to_same_rank (mmikolajcz, Feb 7, 2025)
ab92426 Add einsum decomposition test cases + minor decomposition improvements (mmikolajcz, Feb 10, 2025)
cd422c2 Add missing docstrings for einsum decomposition (mmikolajcz, Feb 10, 2025)
9e3153e Fix extract diagonal call (mmikolajcz, Feb 10, 2025)
a6cbfd4 Remove dependency on einsum_decompose_ptr (mmikolajcz, Feb 10, 2025)
0726b36 Minor change to remove duplicated converts in favor of single convert… (mmikolajcz, Feb 10, 2025)
4fe01e8 Merge branch 'master' of https://github.com/openvinotoolkit/openvino … (mmikolajcz, Feb 10, 2025)
1d0276d Fix callback lambda (mmikolajcz, Feb 11, 2025)
e3867a2 Improve readability for first two einsum decomposition test cases (mmikolajcz, Feb 11, 2025)
c75fa47 Improve readability for einsum decomposition test 3 (mmikolajcz, Feb 12, 2025)
f4cb1b3 Improve readability of einsum decomposition test 4 (mmikolajcz, Feb 12, 2025)
509e302 Improve readability for einsum decomposition test with duplicated label (mmikolajcz, Feb 12, 2025)
09533e4 Extract subshape extraction to separate function for readability (mmikolajcz, Feb 12, 2025)
83e0317 Extract broadcast_merge_shapes to separate function (mmikolajcz, Feb 12, 2025)
a5e46dd Add helper for diagonal extraction for readability (mmikolajcz, Feb 12, 2025)
6c89a9e Fix const formatting (mmikolajcz, Feb 12, 2025)
3cf8b83 Improve readability for einsum decomposition test (mmikolajcz, Feb 12, 2025)
a5d5db4 Modify einsum broadcast_merge_shapes to use Maximum (mmikolajcz, Feb 13, 2025)
8c44729 Fix typo in einsum decomposition (mmikolajcz, Feb 13, 2025)
e273a50 Refactor handling of repeated labels in einsum decomposition to not c… (mmikolajcz, Feb 18, 2025)
3fc4a70 Remove unnecessary loop (mmikolajcz, Feb 18, 2025)
153aaaa Fix einsum decomposition (mmikolajcz, Feb 19, 2025)
cd7658b Update einsum decomposition test + add inline comments with descriptions (mmikolajcz, Feb 19, 2025)
98c356f Fix CI issue (mmikolajcz, Feb 19, 2025)
5519f2d Merge branch 'master' of https://github.com/openvinotoolkit/openvino … (mmikolajcz, Feb 19, 2025)
00a52b8 Merge branch 'master' of https://github.com/openvinotoolkit/openvino … (mmikolajcz, Feb 20, 2025)
d943aee Update src/common/transformations/src/transformations/op_conversions/… (mmikolajcz, Feb 20, 2025)
4b7627a Compare accuracy in decomposition tests (mmikolajcz, Feb 20, 2025)
eff7dfe Add const_0 to subgraph nodes (mmikolajcz, Feb 20, 2025)
4781c21 Apply requested change to compute_ranges (mmikolajcz, Feb 20, 2025)
cfe5928 Merge branch 'mateuszm/einsum/core' of https://github.com/mmikolajcz/… (mmikolajcz, Feb 20, 2025)
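For context on one of the features the commits above add: in einsum notation, a label repeated within a single input (e.g. the equation "ii->i") selects that input's diagonal, which is why several commits introduce diagonal-extraction helpers. A minimal plain-C++ sketch of that semantics (a hypothetical standalone helper for illustration, not the OpenVINO decomposition, which builds the equivalent extraction out of graph operations):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: einsum "ii->i" on an n x n matrix stored
// row-major keeps only the entries where both indices coincide,
// i.e. the main diagonal.
std::vector<float> einsum_ii_to_i(const std::vector<float>& m, std::size_t n) {
    assert(m.size() == n * n);
    std::vector<float> d(n);
    for (std::size_t i = 0; i < n; ++i) {
        d[i] = m[i * n + i];  // element (i, i)
    }
    return d;
}
```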
@@ -9,10 +9,12 @@

#include "itt.hpp"
#include "openvino/core/rt_info.hpp"
+#include "openvino/op/broadcast.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/transpose.hpp"
+#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"

namespace {
@@ -124,9 +126,16 @@ ov::pass::TransposeReshapeEliminationForMatmul::TransposeReshapeEliminationForMatmul()
auto transpose_before_pattern =
ov::pass::pattern::wrap_type<ov::op::v1::Transpose>({input_2_pattern, const_transpose_before_pattern});

+auto const_optional_broadcast_before_pattern = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
+auto optional_broadcast_before_pattern = ov::pass::pattern::wrap_type<ov::op::v3::Broadcast>(
+    {transpose_before_pattern, const_optional_broadcast_before_pattern});
+
+auto transpose_or_transpose_broadcast = std::make_shared<ov::pass::pattern::op::Or>(
+    OutputVector{transpose_before_pattern, optional_broadcast_before_pattern});

auto const_reshape_before_pattern = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
-auto reshape_before_pattern =
-    ov::pass::pattern::wrap_type<ov::op::v1::Reshape>({transpose_before_pattern, const_reshape_before_pattern});
+auto reshape_before_pattern = ov::pass::pattern::wrap_type<ov::op::v1::Reshape>(
+    {transpose_or_transpose_broadcast, const_reshape_before_pattern});

auto matmul_pattern = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({input_1_pattern, reshape_before_pattern});

@@ -181,8 +190,37 @@ ov::pass::TransposeReshapeEliminationForMatmul::TransposeReshapeEliminationForMatmul()
// transposes
if (!check_transposes(transpose_before_order, transpose_after_order, transposed_b))
return false;

-const auto new_matmul = std::make_shared<ov::op::v0::MatMul>(input_1, input_2, transposed_a, false);
+auto matmul_2_input = input_2;
+// for einsum decomposition, check whether a Broadcast exists and, if so, reorder its target shape based on the transpose order
+if (pattern_value_map.count(optional_broadcast_before_pattern)) {
+    auto broadcast_before = ov::as_type_ptr<ov::op::v3::Broadcast>(
+        pattern_value_map.at(optional_broadcast_before_pattern).get_node_shared_ptr());
+    if (!broadcast_before) {
+        return false;
+    }
+    auto broadcast_before_constant =
+        ov::as_type_ptr<ov::op::v0::Constant>(broadcast_before->get_input_node_shared_ptr(1));
+    if (!broadcast_before_constant) {
+        return false;
+    }
+    auto broadcast_shape_after_transpose = broadcast_before_constant->cast_vector<int64_t>();
+    if (broadcast_shape_after_transpose.size() != transpose_before_order.size()) {
+        return false;
+    }
+    std::vector<int64_t> broadcast_shape_no_transpose;
+    broadcast_shape_no_transpose.reserve(broadcast_shape_after_transpose.size());
+    for (auto idx : transpose_before_order) {
+        broadcast_shape_no_transpose.push_back(broadcast_shape_after_transpose[idx]);
+    }
+    auto broadcast_shape_no_transpose_constant =
+        ov::op::v0::Constant::create(element::i64,
+                                     broadcast_before_constant->get_shape(),
+                                     broadcast_shape_no_transpose);
+    matmul_2_input = broadcast_before->clone_with_new_inputs({input_2, broadcast_shape_no_transpose_constant});
+    copy_runtime_info(broadcast_before, matmul_2_input.get_node_shared_ptr());
+}
+
+const auto new_matmul = std::make_shared<ov::op::v0::MatMul>(input_1, matmul_2_input, transposed_a, false);
new_matmul->set_friendly_name(transpose_after->get_friendly_name());
copy_runtime_info({transpose_before, reshape_before, matmul, reshape_after, transpose_after}, new_matmul);
replace_node(transpose_after, new_matmul);
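The shape reordering performed in the matcher callback can be sketched in isolation. Below is a minimal standalone version of the permutation loop (a hypothetical helper with hypothetical example values, mirroring the logic rather than the OpenVINO API):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Permute a Broadcast target shape by the Transpose order, as the
// matcher callback does: element i of the reordered shape is taken
// from position order[i] of the shape that was authored for the
// transposed tensor, so the Broadcast can be moved in front of the
// Transpose and the Transpose/Reshape pair eliminated.
std::vector<int64_t> reorder_target_shape(const std::vector<int64_t>& shape_after_transpose,
                                          const std::vector<int64_t>& order) {
    assert(shape_after_transpose.size() == order.size());
    std::vector<int64_t> reordered;
    reordered.reserve(order.size());
    for (auto idx : order) {
        reordered.push_back(shape_after_transpose[idx]);
    }
    return reordered;
}
```

For example, with transpose order {0, 2, 1} a target shape of {2, 4, 3} for the transposed tensor becomes {2, 3, 4} for the original one.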

Large diffs are not rendered by default.

@@ -138,11 +138,21 @@ TEST_F(TransformationTestsF, TransposeReshapeEliminationForMatMul_Einsum) {
{
auto data_1 = std::make_shared<ov::op::v0::Parameter>(element::f32, data_shape_1);
auto data_2 = std::make_shared<ov::op::v0::Parameter>(element::f32, data_shape_2);
+auto broadcast_shape_constant_1 =
+    std::make_shared<ov::op::v0::Constant>(element::i64, Shape{data_shape_1.size()}, data_shape_1);
+auto broadcast_shape_constant_2 =
+    std::make_shared<ov::op::v0::Constant>(element::i64, Shape{data_shape_2.size()}, data_shape_2);
+auto broadcast_1 = std::make_shared<ov::op::v3::Broadcast>(data_1,
+                                                           broadcast_shape_constant_1,
+                                                           ov::op::BroadcastType::BIDIRECTIONAL);
+auto broadcast_2 = std::make_shared<ov::op::v3::Broadcast>(data_2,
+                                                           broadcast_shape_constant_2,
+                                                           ov::op::BroadcastType::BIDIRECTIONAL);
// in some cases, Reshape may be the first input of MatMul
auto shape_constant =
std::make_shared<ov::op::v0::Constant>(element::i64, Shape{data_shape_1.size()}, data_shape_1);
-auto reshape = std::make_shared<ov::op::v1::Reshape>(data_1, shape_constant, false);
-auto matmul = std::make_shared<ov::op::v0::MatMul>(reshape, data_2, false, false);
+auto reshape = std::make_shared<ov::op::v1::Reshape>(broadcast_1, shape_constant, false);
+auto matmul = std::make_shared<ov::op::v0::MatMul>(reshape, broadcast_2, false, false);
model_ref = std::make_shared<Model>(NodeVector{matmul}, ParameterVector{data_1, data_2});
}
}
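The BroadcastType::BIDIRECTIONAL mode used in the reference model above merges the input shape with the target shape numpy-style rather than requiring the target to dominate. A standalone sketch of that shape rule for static shapes (a hypothetical helper, not the OpenVINO shape inference; the commit "Modify einsum broadcast_merge_shapes to use Maximum" suggests the decomposition computes the same rule with a Maximum op):

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Numpy-style bidirectional broadcast of two static shapes: right-align
// them, left-pad the shorter one with 1s, then take the per-dimension
// maximum. A valid broadcast requires each aligned pair to be equal or
// to contain a 1; this sketch assumes valid inputs and does not validate.
std::vector<int64_t> merge_shapes_bidirectional(std::vector<int64_t> a, std::vector<int64_t> b) {
    if (a.size() < b.size()) {
        std::swap(a, b);  // make `a` the longer shape
    }
    b.insert(b.begin(), a.size() - b.size(), int64_t{1});  // left-pad with 1s
    std::vector<int64_t> out(a.size());
    for (std::size_t i = 0; i < a.size(); ++i) {
        out[i] = std::max(a[i], b[i]);
    }
    return out;
}
```

In the test above the target shapes equal the data shapes, so each Broadcast is a no-op that only exercises the new matcher branch.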