From 58766e7c7606b38767710429daf8fb11e147e55b Mon Sep 17 00:00:00 2001
From: Alexandra Sidorova
Date: Fri, 17 Jan 2025 18:04:03 +0400
Subject: [PATCH] [Snippets] Implemented SetDynamicWAToOuterMostLoop pass
 (#28505)

### Details:
 - *Dynamic MHA Subgraphs may have only a dynamic batch. In this case the pass
`MHAParallelWAOptimizer` cannot be applied to such a Subgraph to increase the
parallel work amount, because the outermost Loop by M in MHA has a static work
amount, so the Subgraph may be executed inefficiently. This PR implements the
pass `SetDynamicWAToOuterMostLoop`, which sets a dynamic work amount to the
outermost Loop in dynamic MHA so that `MHAParallelWAOptimizer` becomes
applicable at runtime.*

### Tickets:
 - *160647*
---
 .../pass/mha_parallel_wa_optimizer.hpp        |  9 ++-
 .../pass/set_dynamic_wa_to_outermost_loop.hpp | 30 ++++++++
 .../pass/mha_parallel_wa_optimizer.cpp        |  8 +-
 .../pass/set_dynamic_wa_to_outermost_loop.cpp | 73 +++++++++++++++++++
 src/common/snippets/src/op/subgraph.cpp       |  2 +
 .../snippets/mha_wo_transpose.cpp             |  5 ++
 6 files changed, 123 insertions(+), 4 deletions(-)
 create mode 100644 src/common/snippets/include/snippets/lowered/pass/set_dynamic_wa_to_outermost_loop.hpp
 create mode 100644 src/common/snippets/src/lowered/pass/set_dynamic_wa_to_outermost_loop.cpp

diff --git a/src/common/snippets/include/snippets/lowered/pass/mha_parallel_wa_optimizer.hpp b/src/common/snippets/include/snippets/lowered/pass/mha_parallel_wa_optimizer.hpp
index 2f42a523ec4eac..7a49f5942f1db2 100644
--- a/src/common/snippets/include/snippets/lowered/pass/mha_parallel_wa_optimizer.hpp
+++ b/src/common/snippets/include/snippets/lowered/pass/mha_parallel_wa_optimizer.hpp
@@ -12,6 +12,8 @@ namespace ov {
 namespace snippets {
 namespace lowered {
 namespace pass {
+
+class SetDynamicWAToOuterMostLoop;
 /**
  * @class MHAParallelWAOptimizer
  * @brief Optimizes the dynamic MHA execution increasing parallel work amount dy dividing Brgemm's "M" dimension to "parallel_m"
@@ -22,6 +24,7 @@ namespace pass {
  * - Determines loops that should be adjusted.
  */
 class MHAParallelWAOptimizer : public lowered::pass::RuntimeOptimizer {
+    friend class SetDynamicWAToOuterMostLoop;
 public:
     OPENVINO_RTTI("MHAParallelWAOptimizer", "", RuntimeOptimizer)
     MHAParallelWAOptimizer() = default;
@@ -31,10 +34,14 @@ class MHAParallelWAOptimizer : public lowered::pass::RuntimeOptimizer {
     bool applicable() const override { return !m_loops_to_split.empty(); }
 
 private:
-    static std::unordered_set<lowered::ExpressionPtr> find_applicable_brgemms(const lowered::LinearIRCPtr& linear_ir);
+    static std::unordered_set<lowered::ExpressionPtr> find_applicable_brgemms(
+        const lowered::LinearIRCPtr& linear_ir,
+        bool check_dynamic_wa = true);
+
     static std::unordered_set<size_t> find_unsqueezed_params(
         const lowered::LinearIRCPtr& linear_ir,
         const std::unordered_set<lowered::ExpressionPtr>& brgemms);
+
     static std::vector<lowered::ExpandedLoopInfoPtr> find_loops_to_split(
         const lowered::LinearIRCPtr& linear_ir,
         const std::unordered_set<size_t>& unsqueezed_params);
diff --git a/src/common/snippets/include/snippets/lowered/pass/set_dynamic_wa_to_outermost_loop.hpp b/src/common/snippets/include/snippets/lowered/pass/set_dynamic_wa_to_outermost_loop.hpp
new file mode 100644
index 00000000000000..6daeb97ec8c566
--- /dev/null
+++ b/src/common/snippets/include/snippets/lowered/pass/set_dynamic_wa_to_outermost_loop.hpp
@@ -0,0 +1,30 @@
+// Copyright (C) 2023-2025 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "pass.hpp"
+
+namespace ov {
+namespace snippets {
+namespace lowered {
+namespace pass {
+
+/**
+ * @interface SetDynamicWAToOuterMostLoop
+ * @brief The pass sets a dynamic work amount to the outermost Loop by M in dynamic MHA Subgraphs
+ *        to allow MHAParallelWAOptimizer to optimize the parallel work amount at runtime.
+ * @ingroup snippets
+ */
+class SetDynamicWAToOuterMostLoop : public Pass {
+public:
+    OPENVINO_RTTI("SetDynamicWAToOuterMostLoop", "", Pass);
+    SetDynamicWAToOuterMostLoop() = default;
+    bool run(LinearIR& linear_ir) override;
+};
+
+} // namespace pass
+} // namespace lowered
+} // namespace snippets
+} // namespace ov
diff --git a/src/common/snippets/src/lowered/pass/mha_parallel_wa_optimizer.cpp b/src/common/snippets/src/lowered/pass/mha_parallel_wa_optimizer.cpp
index c75d1e86abbfa5..bb01346f4eff7d 100644
--- a/src/common/snippets/src/lowered/pass/mha_parallel_wa_optimizer.cpp
+++ b/src/common/snippets/src/lowered/pass/mha_parallel_wa_optimizer.cpp
@@ -85,7 +85,9 @@ bool MHAParallelWAOptimizer::run(const lowered::LinearIR& linear_ir) {
     return true;
 }
 
-std::unordered_set<lowered::ExpressionPtr> MHAParallelWAOptimizer::find_applicable_brgemms(const lowered::LinearIRCPtr& linear_ir) {
+std::unordered_set<lowered::ExpressionPtr> MHAParallelWAOptimizer::find_applicable_brgemms(
+    const lowered::LinearIRCPtr& linear_ir,
+    bool check_dynamic_wa) {
     auto is_brgemm = [](const lowered::ExpressionPtr& expr) {
         return ov::is_type<op::Brgemm>(expr->get_node());
     };
@@ -96,12 +98,12 @@ std::unordered_set MHAParallelWAOptimizer::find_applicab
         brgemm_it = std::find_if(std::next(brgemm_it), linear_ir->end(), is_brgemm);
     }
     const auto& loop_manager = linear_ir->get_loop_manager();
-    auto applicable_brgemm = [&loop_manager](const lowered::ExpressionPtr& expr) {
+    auto applicable_brgemm = [&loop_manager, check_dynamic_wa](const lowered::ExpressionPtr& expr) {
         const auto& loop_idces = expr->get_loop_ids();
         if (loop_idces.empty())
             return false;
         const auto& outermost_loop = loop_manager->get_loop_info(loop_idces[0]);
-        if (!snippets::utils::is_dynamic_value(outermost_loop->get_work_amount()))
+        if (check_dynamic_wa && !snippets::utils::is_dynamic_value(outermost_loop->get_work_amount()))
             return false;
         bool loop_by_m = true;
         outermost_loop->iterate_through_ports([&loop_by_m](const lowered::LoopPort& port) {
diff --git a/src/common/snippets/src/lowered/pass/set_dynamic_wa_to_outermost_loop.cpp b/src/common/snippets/src/lowered/pass/set_dynamic_wa_to_outermost_loop.cpp
new file mode 100644
index 00000000000000..8a5db80f577aee
--- /dev/null
+++ b/src/common/snippets/src/lowered/pass/set_dynamic_wa_to_outermost_loop.cpp
@@ -0,0 +1,73 @@
+// Copyright (C) 2023-2025 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "snippets/lowered/pass/set_dynamic_wa_to_outermost_loop.hpp"
+
+#include "snippets/lowered/pass/mha_parallel_wa_optimizer.hpp"
+#include "snippets/itt.hpp"
+#include "snippets/lowered/linear_ir.hpp"
+#include "snippets/lowered/loop_manager.hpp"
+#include "snippets/op/brgemm.hpp"
+#include "snippets/utils/loop_utils.hpp"
+
+namespace ov {
+namespace snippets {
+namespace lowered {
+namespace pass {
+
+bool SetDynamicWAToOuterMostLoop::run(LinearIR& linear_ir) {
+    OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::SetDynamicWAToOuterMostLoop")
+    if (linear_ir.empty() || !linear_ir.is_dynamic() || linear_ir.get_config().m_enable_domain_optimization)
+        return false;
+
+    const auto linear_ir_ptr = std::make_shared<LinearIR>(linear_ir);
+    const auto brgemms = MHAParallelWAOptimizer::find_applicable_brgemms(linear_ir_ptr, false);
+    if (brgemms.empty())
+        return false;
+
+    const auto unsqueezed_params = MHAParallelWAOptimizer::find_unsqueezed_params(linear_ir_ptr, brgemms);
+    OPENVINO_ASSERT(!unsqueezed_params.empty(), "unsqueezed_params mustn't be empty after initialization");
+
+
+    const auto& loop_manager = linear_ir_ptr->get_loop_manager();
+    std::unordered_set<lowered::UnifiedLoopInfoPtr> affected_loops;
+    size_t prev_loop_id = std::numeric_limits<size_t>::max();
+    static const size_t dim_M_idx = 1;
+
+    auto add_affected_loop = [&](const lowered::ExpressionPtr& expr) {
+        const auto& loop_idces = expr->get_loop_ids();
+        if (loop_idces.empty() || loop_idces.front() == prev_loop_id)
+            return;
+
+        prev_loop_id = loop_idces.front();
+        const auto loop_info = loop_manager->get_loop_info<lowered::UnifiedLoopInfo>(prev_loop_id);
+        if (loop_info->get_dim_idx() == dim_M_idx) {
+            affected_loops.insert(loop_info);
+        }
+    };
+
+    size_t i = 0;
+    std::unordered_set<lowered::ExpressionPtr> visited;
+    for (const auto& param : linear_ir_ptr->get_parameters()) {
+        if (unsqueezed_params.count(i++))
+            continue;
+        utils::visit_path(param, visited, add_affected_loop, false);
+    }
+
+    bool modified = false;
+    for (const auto& loop : affected_loops) {
+        if (!utils::is_dynamic_value(loop->get_work_amount())) {
+            loop->set_work_amount(utils::get_dynamic_value<size_t>());
+            ov::snippets::utils::update_data_pointer_shifts(loop);
+            modified = true;
+        }
+    }
+
+    return modified;
+}
+
+} // namespace pass
+} // namespace lowered
+} // namespace snippets
+} // namespace ov
\ No newline at end of file
diff --git a/src/common/snippets/src/op/subgraph.cpp b/src/common/snippets/src/op/subgraph.cpp
index ecfa72bcb20919..42820889e2f63f 100644
--- a/src/common/snippets/src/op/subgraph.cpp
+++ b/src/common/snippets/src/op/subgraph.cpp
@@ -54,6 +54,7 @@
 #include "snippets/lowered/pass/validate_expanded_loops.hpp"
 #include "snippets/lowered/pass/set_load_store_scalar.hpp"
 #include "snippets/lowered/pass/extract_loop_invariants.hpp"
+#include "snippets/lowered/pass/set_dynamic_wa_to_outermost_loop.hpp"
 
 #include "snippets/lowered/pass/init_registers.hpp"
 
@@ -467,6 +468,7 @@ void Subgraph::control_flow_transformations(size_t min_parallel_work_amount, si
     pipeline.register_pass();
     pipeline.register_pass();
     pipeline.register_pass();
+    pipeline.register_pass<lowered::pass::SetDynamicWAToOuterMostLoop>();
     pipeline.register_pass();
     pipeline.register_pass(m_linear_ir->get_config().m_are_buffers_optimized);
     pipeline.register_pass();
diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha_wo_transpose.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha_wo_transpose.cpp
index 0967ef27087674..c6b11f48efa24c 100644
--- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha_wo_transpose.cpp
+++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha_wo_transpose.cpp
@@ -44,6 +44,11 @@ std::vector> originalShape_3D {
         {PartialShape{2, -1, 64}, {{2, 9, 64}, {2, 4, 64}, {2, 9, 64}}},
         {PartialShape{2, 64, -1}, {{2, 64, 9}, {2, 64, 4}, {2, 64, 9}}},
         {PartialShape{2, -1, 64}, {{2, 9, 64}, {2, 4, 64}, {2, 9, 64}}},
+    },
+    {
+        {PartialShape{-1, 128, 64}, {{1, 128, 64}, {2, 128, 64}, {1, 128, 64}}},
+        {PartialShape{-1, 64, 128}, {{1, 64, 128}, {2, 64, 128}, {1, 64, 128}}},
+        {PartialShape{-1, 128, 64}, {{1, 128, 64}, {2, 128, 64}, {1, 128, 64}}},
     }
 };
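
`MHAParallelWAOptimizer` increases the parallel work amount by splitting Brgemm's "M" dimension into "parallel_m" and "kernel_m" blocks, so the parallel domain becomes `batch * parallel_m` instead of just `batch`. The sketch below is only a conceptual, self-contained illustration of that splitting idea: every name, the heuristic, and the numbers are illustrative and are not the OpenVINO API. It shows why the decision must be made at runtime, i.e. why the outermost Loop by M has to keep a dynamic work amount when only the batch dimension is dynamic.

```cpp
#include <cstddef>
#include <iostream>

// Pick (parallel_m, kernel_m) so that batch * parallel_m can occupy all threads.
// M is known only at runtime, which is why the outermost M-loop must stay dynamic.
static bool split_m(size_t m, size_t batch, size_t n_threads, size_t min_kernel_m,
                    size_t& parallel_m, size_t& kernel_m) {
    if (batch >= n_threads || m < 2 * min_kernel_m)
        return false;  // already enough parallel work, or M is too small to split
    for (size_t km = m / 2; km >= min_kernel_m; --km) {
        if (m % km == 0 && batch * (m / km) >= n_threads) {
            kernel_m = km;
            parallel_m = m / km;
            return true;
        }
    }
    return false;
}

int main() {
    size_t parallel_m = 0, kernel_m = 0;
    // batch = 2, M = 128, 16 threads: without splitting only 2 threads have work.
    if (split_m(/*m=*/128, /*batch=*/2, /*n_threads=*/16, /*min_kernel_m=*/8, parallel_m, kernel_m))
        std::cout << "parallel_m = " << parallel_m << ", kernel_m = " << kernel_m << "\n";  // 8 and 16
    return 0;
}
```

With a static work amount on the outermost Loop by M, such a decision would be frozen at compilation time and the optimizer would never fire for shapes where only the batch is dynamic; setting the work amount to the dynamic sentinel keeps the choice open until the actual shapes are known.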