[Transformations] Make ov::ModelPass transformations execute recursively
Some ov::ModelPass transformations lack recursive execution for
subgraphs, leaving them unprocessed.
Add the required recursive call for MultiSubGraphOp operations.

Ticket: CVS-133627
Signed-off-by: Andrii Staikov <[email protected]>
CuriousPanCake committed Feb 26, 2024
1 parent a5f6308 commit 2735038
Showing 11 changed files with 91 additions and 12 deletions.
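
All eleven changes apply the same pattern: before the pass-specific handling of each node, descend into every body held by a MultiSubGraphOp (If, Loop, TensorIterator) and re-run the pass on it. A minimal standalone sketch of that pattern is shown below; ExamplePass is a hypothetical pass name used only for illustration, and the exact include paths are assumptions rather than taken from this commit.

#include <memory>

#include "openvino/core/model.hpp"
#include "openvino/op/util/multi_subgraph_base.hpp"  // ov::op::util::MultiSubGraphOp (assumed header)
#include "openvino/pass/pass.hpp"                    // ov::pass::ModelPass

class ExamplePass : public ov::pass::ModelPass {
public:
    OPENVINO_RTTI("ExamplePass", "0");

    bool run_on_model(const std::shared_ptr<ov::Model>& model) override {
        bool changed = false;
        for (const auto& node : model->get_ordered_ops()) {
            // Recurse into every nested ov::Model held by a sub-graph based operation
            // so nodes inside If/Loop/TensorIterator bodies are processed as well.
            if (const auto& multi_subgraph_op = std::dynamic_pointer_cast<ov::op::util::MultiSubGraphOp>(node)) {
                for (const auto& sub_graph : multi_subgraph_op->get_functions()) {
                    if (sub_graph) {
                        changed |= run_on_model(sub_graph);
                    }
                }
            }
            // ... pass-specific handling of `node` goes here ...
        }
        return changed;
    }
};

Casting to MultiSubGraphOp rather than SubGraphOp matters because SubGraphOp (Loop, TensorIterator) exposes a single body, while MultiSubGraphOp also covers If with its two bodies; several files below replace an existing SubGraphOp-based recursion with the MultiSubGraphOp variant for exactly that reason.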
@@ -44,6 +44,14 @@ class ov::pass::low_precision::PropagateSharedValue : public ov::pass::ModelPass
std::vector<std::shared_ptr<ov::Node>> nodes(f->get_ordered_ops());
for (auto it = nodes.begin(); it != nodes.end(); it++) {
const std::shared_ptr<Node> node = *it;
if (const auto& multi_subgraph_op = std::dynamic_pointer_cast<op::util::MultiSubGraphOp>(node)) {
for (const auto& sub_graph : multi_subgraph_op->get_functions()) {
if (sub_graph) {
run_on_model(sub_graph);
}
}
}

if (ov::is_type<opset1::FakeQuantize>(node)) {
assert(node->get_output_size() == 1ul);
auto& outputRtInfo = node->output(0).get_rt_info();
8 changes: 8 additions & 0 deletions src/common/snippets/src/pass/propagate_precision.cpp
@@ -29,6 +29,14 @@ bool ov::snippets::pass::PropagatePrecision::run_on_model(const std::shared_ptr<

bool was_updated = false;
for (const auto& op : f->get_ordered_ops()) {
if (const auto& multi_subgraph_op = std::dynamic_pointer_cast<ov::op::util::MultiSubGraphOp>(op)) {
for (const auto& sub_graph : multi_subgraph_op->get_functions()) {
if (sub_graph) {
run_on_model(sub_graph);
}
}
}

auto type_info = op->get_type_info();
std::set<ov::element::TypeVector> supported_precisions;
// TODO: At the moment Softmax is decomposed on Linear IR level.
@@ -4,13 +4,24 @@

#include "transformations/common_optimizations/fused_names_cleanup.hpp"

#include <memory>

#include "openvino/cc/pass/itt.hpp"
#include "transformations/rt_info/fused_names_attribute.hpp"
#include "transformations/utils/utils.hpp"

bool ov::pass::FusedNamesCleanup::run_on_model(const std::shared_ptr<ov::Model>& f) {
RUN_ON_FUNCTION_SCOPE(FusedNamesCleanup);

for (auto& node : f->get_ordered_ops()) {
if (const auto& multi_subgraph_op = std::dynamic_pointer_cast<op::util::MultiSubGraphOp>(node)) {
for (const auto& sub_graph : multi_subgraph_op->get_functions()) {
if (sub_graph) {
run_on_model(sub_graph);
}
}
}

RTMap& rt_info = node->get_rt_info();
auto it = rt_info.find(ov::FusedNames::get_type_info_static());
if (it != rt_info.end()) {
@@ -27,9 +27,11 @@ bool ov::pass::UselessSliceEraser::run_on_model(const std::shared_ptr<ov::Model>
bool rewritten = false;
for (auto& node : f->get_ordered_ops()) {
// Recursively apply transformation for sub-graph based operations
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node)) {
if (auto sub_graph = sub_graph_node->get_function()) {
rewritten |= run_on_model(sub_graph);
if (const auto& multi_subgraph_node = std::dynamic_pointer_cast<op::util::MultiSubGraphOp>(node)) {
for (const auto& sub_graph : multi_subgraph_node->get_functions()) {
if (sub_graph) {
rewritten |= run_on_model(sub_graph);
}
}
}
bool is_slice = ov::is_type<ov::op::v1::StridedSlice>(node) || ov::is_type<ov::op::v8::Slice>(node);
@@ -102,9 +104,11 @@ bool ov::pass::GroupedStridedSliceOptimizer::run_on_model(const std::shared_ptr<
std::map<ov::Output<Node>, std::vector<planned_slice>> source_to_ss_with_plan;
for (const auto& node : f->get_ordered_ops()) {
// Recursively apply transformation for sub-graph based operations
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node)) {
if (auto sub_graph = sub_graph_node->get_function()) {
graph_rewritten |= run_on_model(sub_graph);
if (const auto& multi_subgraph_op = std::dynamic_pointer_cast<op::util::MultiSubGraphOp>(node)) {
for (const auto& sub_graph : multi_subgraph_op->get_functions()) {
if (sub_graph) {
graph_rewritten |= run_on_model(sub_graph);
}
}
}
if (auto ss = std::dynamic_pointer_cast<ov::op::v1::StridedSlice>(node)) {
@@ -70,6 +70,14 @@ bool ov::pass::ReverseShapeAndTypeInfer::run_on_model(const std::shared_ptr<ov::
auto ops = f->get_ordered_ops();
for (auto it = ops.rbegin(); it != ops.rend(); ++it) {
const auto& op = *it;
if (const auto& multi_subgraph_op = std::dynamic_pointer_cast<op::util::MultiSubGraphOp>(op)) {
for (const auto& sub_graph : multi_subgraph_op->get_functions()) {
if (sub_graph) {
is_changed |= run_on_model(sub_graph);
}
}
}

auto output_shape = op->get_output_partial_shape(0);
auto output_type = op->get_output_element_type(0);
if (const auto& param = std::dynamic_pointer_cast<ov::op::v0::Parameter>(op)) {
@@ -23,6 +23,14 @@
bool ov::pass::UnrollTensorIterator::run_on_model(const std::shared_ptr<ov::Model>& f) {
RUN_ON_FUNCTION_SCOPE(UnrollTensorIterator);
for (const auto& op : f->get_ops()) {
if (const auto& multi_subgraph_op = std::dynamic_pointer_cast<op::util::MultiSubGraphOp>(op)) {
for (const auto& sub_graph : multi_subgraph_op->get_functions()) {
if (sub_graph) {
run_on_model(sub_graph);
}
}
}

auto sub_graph_op = std::dynamic_pointer_cast<op::util::SubGraphOp>(op);
if (!sub_graph_op || transformation_callback(sub_graph_op)) {
continue;
@@ -17,9 +17,11 @@ bool ov::pass::InitNodeInfo::run_on_model(const std::shared_ptr<ov::Model>& f) {

for (auto& node : f->get_ops()) {
// Recursively apply transformation for sub-graph based operations
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node)) {
if (auto sub_graph = sub_graph_node->get_function()) {
run_on_model(sub_graph);
if (const auto& multi_subgraph_op = std::dynamic_pointer_cast<op::util::MultiSubGraphOp>(node)) {
for (const auto& sub_graph : multi_subgraph_op->get_functions()) {
if (sub_graph) {
run_on_model(sub_graph);
}
}
}
auto& rtInfo = node->get_rt_info();
@@ -163,9 +163,13 @@ bool ov::pass::LSTMStatesBroadcast::run_on_model(const shared_ptr<ov::Model>& f)
bool rewritten = false;
for (auto& node : f->get_ordered_ops()) {
// Recursively apply transformation for sub-graph based operations
if (const auto& sub_graph_node = dynamic_pointer_cast<ov::op::util::SubGraphOp>(node))
if (const auto& sub_graph = sub_graph_node->get_function())
rewritten |= run_on_model(sub_graph);
if (const auto& multi_subgraph_op = dynamic_pointer_cast<ov::op::util::MultiSubGraphOp>(node)) {
for (const auto& sub_graph : multi_subgraph_op->get_functions()) {
if (sub_graph) {
rewritten |= run_on_model(sub_graph);
}
}
}

// Case without TI (LSTMCell and Constant are in the same ov::Model)
if (const auto& lstm_cell = dynamic_pointer_cast<ov::op::v4::LSTMCell>(node))
8 changes: 8 additions & 0 deletions src/core/src/pass/low_latency.cpp
@@ -259,6 +259,14 @@ bool ov::pass::LowLatency2::run_on_model(const std::shared_ptr<Model>& f) {

ov::SinkVector assigns;
for (const auto& op : f->get_ordered_ops()) {
if (const auto& multi_subgraph_op = std::dynamic_pointer_cast<MultiSubGraphOp>(op)) {
for (const auto& sub_graph : multi_subgraph_op->get_functions()) {
if (sub_graph) {
run_on_model(sub_graph);
}
}
}

if (const auto& sub_graph_op = std::dynamic_pointer_cast<SubGraphOp>(op)) {
int64_t variable_id = 0;
const auto& func = sub_graph_op->get_function();
@@ -30,6 +30,14 @@ bool EnforcePrecision::run_on_model(const std::shared_ptr<ov::Model>& f) {

bool was_updated = false;
for (const auto& op : f->get_ordered_ops()) {
if (const auto& multi_subgraph_op = std::dynamic_pointer_cast<op::util::MultiSubGraphOp>(op)) {
for (const auto& sub_graph : multi_subgraph_op->get_functions()) {
if (sub_graph) {
run_on_model(sub_graph);
}
}
}

const auto& precisions = get_supported_precisions(op);

if (precisions.empty()) {
10 changes: 10 additions & 0 deletions src/plugins/intel_cpu/src/utils/print_model.hpp
@@ -402,6 +402,16 @@ class OPENVINO_API PrintModel : public ov::pass::ModelPass {
if (m_file_name.empty())
return false;

for (auto& node : model->get_ordered_ops()) {
if (const auto& multi_subgraph_op = std::dynamic_pointer_cast<op::util::MultiSubGraphOp>(node)) {
for (const auto& sub_graph : multi_subgraph_op->get_functions()) {
if (sub_graph) {
run_on_model(sub_graph);
}
}
}
}

std::ofstream ofs(m_file_name);
if (!ofs) {
// OPENVINO_WARN << "Error opening file " << m_file_name << " for output" << std::endl;
