From 8c36c0047377303b0406a74c45195cf29a460b2f Mon Sep 17 00:00:00 2001 From: Ivan Tikhonov Date: Fri, 18 Oct 2024 21:57:50 +0400 Subject: [PATCH] [CORE] Skip unnecessary convert_to_supported_precision if ConstantFolding is omitted (#26756) Details: It's a modification of https://github.com/openvinotoolkit/openvino/pull/22674 f16 LLM (llama was tested) compilation time on ARM is unreasonable huge. Perf report shows that every ConstantFolding transformation takes several seconds even if the graph is not modified. The root cause is util::convert_to_supported_precision call even if constant folding is skipped. The suggested fix is to skip util::convert_to_supported_precision call if folding is not applied. Tickets: CVS-152428 --------- Co-authored-by: Aleksandr Voron Co-authored-by: Andrii Staikov --- src/core/include/openvino/core/node.hpp | 1 + src/core/include/openvino/op/assign.hpp | 2 +- src/core/include/openvino/op/constant.hpp | 2 +- src/core/include/openvino/op/convert_like.hpp | 1 + .../include/openvino/op/fake_quantize.hpp | 3 ++- .../include/openvino/op/random_uniform.hpp | 2 +- src/core/include/openvino/op/read_value.hpp | 2 +- src/core/include/openvino/op/reshape.hpp | 1 + src/core/include/openvino/op/result.hpp | 2 +- src/core/include/openvino/op/shape_of.hpp | 2 ++ src/core/include/openvino/op/squeeze.hpp | 1 + .../include/openvino/op/strided_slice.hpp | 1 + src/core/include/openvino/op/unsqueeze.hpp | 1 + .../include/openvino/op/util/gather_base.hpp | 1 + src/core/src/node.cpp | 14 ++++++++--- src/core/src/op/assign.cpp | 2 +- src/core/src/op/constant.cpp | 2 +- src/core/src/op/convert_like.cpp | 6 ++++- src/core/src/op/random_uniform.cpp | 2 +- src/core/src/op/read_value.cpp | 2 +- src/core/src/op/reshape.cpp | 6 ++++- src/core/src/op/result.cpp | 2 +- src/core/src/op/shape_of.cpp | 12 ++++++++-- src/core/src/op/squeeze.cpp | 6 ++++- src/core/src/op/strided_slice.cpp | 6 ++++- src/core/src/op/unsqueeze.cpp | 6 ++++- src/core/src/op/util/gather_base.cpp | 17 +++++++------ src/core/src/pass/constant_folding.cpp | 24 ++++++++++++------- 28 files changed, 90 insertions(+), 39 deletions(-) diff --git a/src/core/include/openvino/core/node.hpp b/src/core/include/openvino/core/node.hpp index f5a63911abc502..59a4ab29253ded 100644 --- a/src/core/include/openvino/core/node.hpp +++ b/src/core/include/openvino/core/node.hpp @@ -207,6 +207,7 @@ class OPENVINO_API Node : public std::enable_shared_from_this { virtual bool evaluate_upper(ov::TensorVector& output_values) const; virtual bool evaluate_symbol(TensorSymbolVector& output_symbols) const; + virtual bool can_constant_fold(const OutputVector& inputs_values) const; virtual bool constant_fold(OutputVector& output_values, const OutputVector& inputs_values); /// \brief Decomposes the FusedOp into a sub-graph consisting of core openvino ops /// diff --git a/src/core/include/openvino/op/assign.hpp b/src/core/include/openvino/op/assign.hpp index c3f8492e54b4f8..895f6619778951 100644 --- a/src/core/include/openvino/op/assign.hpp +++ b/src/core/include/openvino/op/assign.hpp @@ -67,7 +67,7 @@ class OPENVINO_API Assign : public util::AssignBase { const TensorVector& inputs, const EvaluationContext& evaluation_context) const override; bool has_evaluate() const override; - bool constant_fold(OutputVector& output_values, const OutputVector& inputs_values) override; + bool can_constant_fold(const OutputVector& inputs_values) const override; }; } // namespace v6 } // namespace op diff --git a/src/core/include/openvino/op/constant.hpp b/src/core/include/openvino/op/constant.hpp index 62b70a883fc1a5..ccaae01586d612 100644 --- a/src/core/include/openvino/op/constant.hpp +++ b/src/core/include/openvino/op/constant.hpp @@ -215,7 +215,7 @@ class OPENVINO_API Constant : public Op { bool evaluate_upper(TensorVector& outputs) const override; // Don't constant fold a constant; it would make a copy - bool constant_fold(OutputVector& outputs, const OutputVector& inputs) override; + bool can_constant_fold(const OutputVector& inputs_values) const override; /// \brief Returns the value of the constant node as a Shape object /// Can only be used on element::i64 nodes and interprets diff --git a/src/core/include/openvino/op/convert_like.hpp b/src/core/include/openvino/op/convert_like.hpp index 244d0f4c7d70b4..0d7f73075e21b9 100644 --- a/src/core/include/openvino/op/convert_like.hpp +++ b/src/core/include/openvino/op/convert_like.hpp @@ -27,6 +27,7 @@ class OPENVINO_API ConvertLike : public Op { std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; bool constant_fold(OutputVector& output_values, const OutputVector& input_values) override; + bool can_constant_fold(const OutputVector& inputs_values) const override; }; } // namespace v1 } // namespace op diff --git a/src/core/include/openvino/op/fake_quantize.hpp b/src/core/include/openvino/op/fake_quantize.hpp index b47c7016c8709e..52caca885a02cc 100644 --- a/src/core/include/openvino/op/fake_quantize.hpp +++ b/src/core/include/openvino/op/fake_quantize.hpp @@ -69,7 +69,8 @@ class OPENVINO_API FakeQuantize : public Op { bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override; bool has_evaluate() const override; - bool constant_fold(OutputVector& output_values, const OutputVector& inputs_values) override { + + bool can_constant_fold(const OutputVector& inputs_values) const override { return false; } diff --git a/src/core/include/openvino/op/random_uniform.hpp b/src/core/include/openvino/op/random_uniform.hpp index 6a4de83715e30a..22f06f79402135 100644 --- a/src/core/include/openvino/op/random_uniform.hpp +++ b/src/core/include/openvino/op/random_uniform.hpp @@ -42,7 +42,7 @@ class OPENVINO_API RandomUniform : public Op { std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; /// \return Turns off constant folding for RandomUniform operation. - bool constant_fold(OutputVector& output_values, const OutputVector& inputs_values) override; + bool can_constant_fold(const OutputVector& inputs_values) const override; /// \return The output tensor type. const ov::element::Type& get_out_type() const; diff --git a/src/core/include/openvino/op/read_value.hpp b/src/core/include/openvino/op/read_value.hpp index 27447644037211..e37d6baa11c01c 100644 --- a/src/core/include/openvino/op/read_value.hpp +++ b/src/core/include/openvino/op/read_value.hpp @@ -80,7 +80,7 @@ class OPENVINO_API ReadValue : public util::ReadValueBase { const EvaluationContext& evaluation_context) const override; bool has_evaluate() const override; - bool constant_fold(OutputVector& output_values, const OutputVector& inputs_values) override; + bool can_constant_fold(const OutputVector& inputs_values) const override; }; } // namespace v6 } // namespace op diff --git a/src/core/include/openvino/op/reshape.hpp b/src/core/include/openvino/op/reshape.hpp index f3a9e7aa8e59c1..48bc08f8c3d947 100644 --- a/src/core/include/openvino/op/reshape.hpp +++ b/src/core/include/openvino/op/reshape.hpp @@ -52,6 +52,7 @@ class OPENVINO_API Reshape : public Op { bool evaluate_lower(TensorVector& outputs) const override; bool evaluate_symbol(TensorSymbolVector& output_symbols) const override; bool constant_fold(OutputVector& output_values, const OutputVector& inputs_values) override; + bool can_constant_fold(const OutputVector& inputs_values) const override; protected: bool m_special_zero; diff --git a/src/core/include/openvino/op/result.hpp b/src/core/include/openvino/op/result.hpp index dc8162a10b6627..00e805d1f2aeb5 100644 --- a/src/core/include/openvino/op/result.hpp +++ b/src/core/include/openvino/op/result.hpp @@ -30,7 +30,7 @@ class OPENVINO_API Result : public Op { bool evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const override; bool has_evaluate() const override; - bool constant_fold(OutputVector& output_values, const OutputVector& inputs_values) override; + bool can_constant_fold(const OutputVector& inputs_values) const override; /// \brief Returns current layout, or empty Layout if it is not set Layout get_layout() const; diff --git a/src/core/include/openvino/op/shape_of.hpp b/src/core/include/openvino/op/shape_of.hpp index c8245d91069ed0..375d087f7e6cf8 100644 --- a/src/core/include/openvino/op/shape_of.hpp +++ b/src/core/include/openvino/op/shape_of.hpp @@ -38,6 +38,7 @@ class OPENVINO_API ShapeOf : public util::ShapeOfBase { bool evaluate_upper(TensorVector& output_values) const override; bool evaluate_symbol(TensorSymbolVector& output_symbols) const override; bool constant_fold(OutputVector& output_values, const OutputVector& input_values) override; + bool can_constant_fold(const OutputVector& inputs_values) const override; private: element::Type m_output_type; @@ -64,6 +65,7 @@ class OPENVINO_API ShapeOf : public util::ShapeOfBase { bool evaluate_upper(TensorVector& output_values) const override; bool evaluate_symbol(TensorSymbolVector& output_symbols) const override; bool constant_fold(OutputVector& output_values, const OutputVector& input_values) override; + bool can_constant_fold(const OutputVector& inputs_values) const override; }; } // namespace v0 } // namespace op diff --git a/src/core/include/openvino/op/squeeze.hpp b/src/core/include/openvino/op/squeeze.hpp index f7cb41f974db2f..8c27f29d66df66 100644 --- a/src/core/include/openvino/op/squeeze.hpp +++ b/src/core/include/openvino/op/squeeze.hpp @@ -27,6 +27,7 @@ class OPENVINO_API Squeeze : public Op { bool evaluate_upper(TensorVector& outputs) const override; bool evaluate_symbol(TensorSymbolVector& output_symbols) const override; bool constant_fold(OutputVector& output_values, const OutputVector& inputs_values) override; + bool can_constant_fold(const OutputVector& inputs_values) const override; std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; diff --git a/src/core/include/openvino/op/strided_slice.hpp b/src/core/include/openvino/op/strided_slice.hpp index 2ba4f84c0936bf..aa080bc6563b90 100644 --- a/src/core/include/openvino/op/strided_slice.hpp +++ b/src/core/include/openvino/op/strided_slice.hpp @@ -114,6 +114,7 @@ class OPENVINO_API StridedSlice : public Op { bool evaluate_upper(TensorVector& outputs) const override; bool evaluate_symbol(TensorSymbolVector& output_symbols) const override; bool constant_fold(OutputVector& output_values, const OutputVector& inputs_values) override; + bool can_constant_fold(const OutputVector& inputs_values) const override; private: AxisSet convert_mask_to_axis_set(const std::vector& mask) const; diff --git a/src/core/include/openvino/op/unsqueeze.hpp b/src/core/include/openvino/op/unsqueeze.hpp index d9839c7d68d719..4701df2dd4d4ec 100644 --- a/src/core/include/openvino/op/unsqueeze.hpp +++ b/src/core/include/openvino/op/unsqueeze.hpp @@ -30,6 +30,7 @@ class OPENVINO_API Unsqueeze : public Op { bool evaluate_symbol(TensorSymbolVector& output_symbols) const override; bool constant_fold(OutputVector& output_values, const OutputVector& inputs_values) override; + bool can_constant_fold(const OutputVector& inputs_values) const override; std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; }; diff --git a/src/core/include/openvino/op/util/gather_base.hpp b/src/core/include/openvino/op/util/gather_base.hpp index f7846b83cfe465..9fa8387aee6b3a 100644 --- a/src/core/include/openvino/op/util/gather_base.hpp +++ b/src/core/include/openvino/op/util/gather_base.hpp @@ -34,6 +34,7 @@ class OPENVINO_API GatherBase : public Op { bool evaluate_symbol(TensorSymbolVector& output_symbols) const override; bool constant_fold(OutputVector& output_values, const OutputVector& inputs_values) override; + bool can_constant_fold(const OutputVector& inputs_values) const override; const int64_t& get_batch_dims() const; void set_batch_dims(int64_t batch_dims); diff --git a/src/core/src/node.cpp b/src/core/src/node.cpp index 0341e4477f4cfb..8b9936b5496e7c 100644 --- a/src/core/src/node.cpp +++ b/src/core/src/node.cpp @@ -696,8 +696,8 @@ bool ov::Node::evaluate_symbol(TensorSymbolVector& output_symbols) const { return false; } -bool ov::Node::constant_fold(OutputVector& output_values, const OutputVector& input_values) { - OV_ITT_SCOPED_TASK(ov::itt::domains::core, "Node::constant_fold"); +bool ov::Node::can_constant_fold(const OutputVector& input_values) const { + OV_ITT_SCOPED_TASK(ov::itt::domains::core, "Node::can_constant_fold"); if (is_const_fold_disabled()) { return false; @@ -707,8 +707,16 @@ bool ov::Node::constant_fold(OutputVector& output_values, const OutputVector& in bool all_constants = std::all_of(input_values.begin(), input_values.end(), [](const Output& input) { return ov::as_type_ptr(input.get_node_shared_ptr()); }); - if (!all_constants) + + return all_constants; +} + +bool ov::Node::constant_fold(OutputVector& output_values, const OutputVector& input_values) { + OV_ITT_SCOPED_TASK(ov::itt::domains::core, "Node::constant_fold"); + + if (!Node::can_constant_fold(input_values)) { return false; + } NodeVector nodes; TensorVector input_tensors; diff --git a/src/core/src/op/assign.cpp b/src/core/src/op/assign.cpp index bf6e55c11b1d39..7798d4328049af 100644 --- a/src/core/src/op/assign.cpp +++ b/src/core/src/op/assign.cpp @@ -134,7 +134,7 @@ bool Assign::has_evaluate() const { return true; } -bool Assign::constant_fold(OutputVector& output_values, const OutputVector& inputs_values) { +bool Assign::can_constant_fold(const OutputVector& input_values) const { return false; } } // namespace v6 diff --git a/src/core/src/op/constant.cpp b/src/core/src/op/constant.cpp index 95df6379ba284e..e06718ef4e1fd5 100644 --- a/src/core/src/op/constant.cpp +++ b/src/core/src/op/constant.cpp @@ -663,7 +663,7 @@ bool Constant::evaluate_upper(TensorVector& outputs) const { return evaluate(outputs, {}); } -bool Constant::constant_fold(OutputVector&, const OutputVector&) { +bool Constant::can_constant_fold(const OutputVector& input_values) const { return false; } diff --git a/src/core/src/op/convert_like.cpp b/src/core/src/op/convert_like.cpp index 3dc0159bb556be..4ae4ea982f8cd9 100644 --- a/src/core/src/op/convert_like.cpp +++ b/src/core/src/op/convert_like.cpp @@ -29,9 +29,13 @@ std::shared_ptr ConvertLike::clone_with_new_inputs(const OutputVector& new return std::make_shared(new_args.at(0), new_args.at(1)); } +bool ConvertLike::can_constant_fold(const OutputVector& input_values) const { + return !is_const_fold_disabled(); +} + bool ConvertLike::constant_fold(OutputVector& output_values, const OutputVector& input_values) { OV_OP_SCOPE(v1_ConvertLike_constant_fold); - if (is_const_fold_disabled()) { + if (!can_constant_fold(input_values)) { return false; } diff --git a/src/core/src/op/random_uniform.cpp b/src/core/src/op/random_uniform.cpp index e62be4d26afc58..9aafed881086b6 100644 --- a/src/core/src/op/random_uniform.cpp +++ b/src/core/src/op/random_uniform.cpp @@ -88,7 +88,7 @@ std::shared_ptr RandomUniform::clone_with_new_inputs(const OutputVector& n } /// \return Turns off constant folding for RandomUniform operation. -bool RandomUniform::constant_fold(OutputVector& output_values, const OutputVector& inputs_values) { +bool RandomUniform::can_constant_fold(const OutputVector& input_values) const { return false; } diff --git a/src/core/src/op/read_value.cpp b/src/core/src/op/read_value.cpp index 162cb5067bc00a..0d63456a3b8348 100644 --- a/src/core/src/op/read_value.cpp +++ b/src/core/src/op/read_value.cpp @@ -176,7 +176,7 @@ bool ReadValue::has_evaluate() const { return true; } -bool ReadValue::constant_fold(OutputVector& output_values, const OutputVector& inputs_values) { +bool ReadValue::can_constant_fold(const OutputVector& input_values) const { return false; } } // namespace v6 diff --git a/src/core/src/op/reshape.cpp b/src/core/src/op/reshape.cpp index ab0e0a0c17cbde..477e210f574269 100644 --- a/src/core/src/op/reshape.cpp +++ b/src/core/src/op/reshape.cpp @@ -97,7 +97,7 @@ bool Reshape::evaluate_symbol(TensorSymbolVector& output_symbols) const { } bool Reshape::constant_fold(OutputVector& output_values, const OutputVector& inputs_values) { - if (get_output_partial_shape(0).is_dynamic() || is_const_fold_disabled()) { + if (!can_constant_fold(inputs_values)) { return false; } @@ -108,6 +108,10 @@ bool Reshape::constant_fold(OutputVector& output_values, const OutputVector& inp return false; } } + +bool Reshape::can_constant_fold(const OutputVector& input_values) const { + return get_output_partial_shape(0).is_static() && !is_const_fold_disabled(); +} } // namespace v1 } // namespace op } // namespace ov diff --git a/src/core/src/op/result.cpp b/src/core/src/op/result.cpp index 3667e5ff22b422..237d6bd7a2084a 100644 --- a/src/core/src/op/result.cpp +++ b/src/core/src/op/result.cpp @@ -67,7 +67,7 @@ bool Result::has_evaluate() const { return true; } -bool Result::constant_fold(OutputVector& output_values, const OutputVector& inputs_values) { +bool Result::can_constant_fold(const OutputVector& input_values) const { return false; } diff --git a/src/core/src/op/shape_of.cpp b/src/core/src/op/shape_of.cpp index 293c1b5fc5a59c..9676a5704ec99c 100644 --- a/src/core/src/op/shape_of.cpp +++ b/src/core/src/op/shape_of.cpp @@ -168,9 +168,13 @@ bool ShapeOf::evaluate_symbol(TensorSymbolVector& output_symbols) const { return shape_of::evaluate_symbol(this, output_symbols); } +bool ShapeOf::can_constant_fold(const OutputVector& input_values) const { + return !is_const_fold_disabled() && input_values[0].get_partial_shape().is_static(); +} + bool ShapeOf::constant_fold(OutputVector& output_values, const OutputVector& input_values) { OV_OP_SCOPE(v3_ShapeOf_constant_fold); - if (is_const_fold_disabled()) { + if (!can_constant_fold(input_values)) { return false; } return shape_of::constant_fold_shape_of(this, output_values[0], input_values[0]); @@ -222,9 +226,13 @@ bool ShapeOf::has_evaluate() const { } } +bool ShapeOf::can_constant_fold(const OutputVector& input_values) const { + return !is_const_fold_disabled() && input_values[0].get_partial_shape().is_static(); +} + bool ShapeOf::constant_fold(OutputVector& output_values, const OutputVector& input_values) { OV_OP_SCOPE(v0_ShapeOf_constant_fold); - if (is_const_fold_disabled()) { + if (!can_constant_fold(input_values)) { return false; } return shape_of::constant_fold_shape_of(this, output_values[0], input_values[0]); diff --git a/src/core/src/op/squeeze.cpp b/src/core/src/op/squeeze.cpp index 3abc0a773192d2..1b34a4e48a4faf 100644 --- a/src/core/src/op/squeeze.cpp +++ b/src/core/src/op/squeeze.cpp @@ -104,9 +104,13 @@ bool Squeeze::evaluate_symbol(TensorSymbolVector& output_symbols) const { return validate::axes_has_and_set_bound(*this) && ov::util::default_symbol_evaluator(this, output_symbols); } +bool Squeeze::can_constant_fold(const OutputVector& inputs_values) const { + return get_output_partial_shape(0).is_static() && !is_const_fold_disabled(); +} + bool Squeeze::constant_fold(OutputVector& output_values, const OutputVector& inputs_values) { OV_OP_SCOPE(v0_Squeeze_constant_fold); - if (get_output_partial_shape(0).is_dynamic() || is_const_fold_disabled()) { + if (!can_constant_fold(inputs_values)) { return false; } diff --git a/src/core/src/op/strided_slice.cpp b/src/core/src/op/strided_slice.cpp index deb89fa9a531d4..83ac3dec7a5f4f 100644 --- a/src/core/src/op/strided_slice.cpp +++ b/src/core/src/op/strided_slice.cpp @@ -283,9 +283,13 @@ bool StridedSlice::evaluate_symbol(TensorSymbolVector& output_symbols) const { default_symbol_evaluator(this, {0}, output_symbols); } +bool StridedSlice::can_constant_fold(const OutputVector& input_values) const { + return !is_const_fold_disabled(); +} + bool StridedSlice::constant_fold(OutputVector& output_values, const OutputVector& inputs_values) { auto is_folded = Node::constant_fold(output_values, inputs_values); - if (!is_const_fold_disabled() && !is_folded) { + if (can_constant_fold(inputs_values) && !is_folded) { // If all ignored mask are set for all begin or end then replace this input by dummy constant // to avoid return false from `could_propagate` during bound evaluation (value of const will be ignored). auto get_indices_input = [&inputs_values](size_t port, const std::vector& mask) -> Output { diff --git a/src/core/src/op/unsqueeze.cpp b/src/core/src/op/unsqueeze.cpp index d199c43a2479b5..f8c14a08f70d30 100644 --- a/src/core/src/op/unsqueeze.cpp +++ b/src/core/src/op/unsqueeze.cpp @@ -77,8 +77,12 @@ bool ov::op::v0::Unsqueeze::evaluate_symbol(TensorSymbolVector& output_symbols) return ov::util::default_symbol_evaluator(this, output_symbols); } +bool ov::op::v0::Unsqueeze::can_constant_fold(const OutputVector& input_values) const { + return get_output_partial_shape(0).is_static() && !is_const_fold_disabled(); +} + bool ov::op::v0::Unsqueeze::constant_fold(OutputVector& output_values, const OutputVector& inputs_values) { - if (get_output_partial_shape(0).is_dynamic() || is_const_fold_disabled()) { + if (!can_constant_fold(inputs_values)) { return false; } diff --git a/src/core/src/op/util/gather_base.cpp b/src/core/src/op/util/gather_base.cpp index 92e41781b1de55..dd35edf695ec16 100644 --- a/src/core/src/op/util/gather_base.cpp +++ b/src/core/src/op/util/gather_base.cpp @@ -32,10 +32,6 @@ Shape out_shape_infer(const Shape& data_shape, const Shape& indices_shape, int64 bool cf_gather_with_subgraph(OutputVector& output_values, const OutputVector& input_values, const PartialShape& gather_ps) { - if (gather_ps.is_dynamic() || input_values.size() != 3) { - return false; - } - const auto concat = std::dynamic_pointer_cast(input_values[0].get_node_shared_ptr()); const auto indices = std::dynamic_pointer_cast(input_values[1].get_node_shared_ptr()); const auto axis = std::dynamic_pointer_cast(input_values[2].get_node_shared_ptr()); @@ -67,7 +63,6 @@ bool cf_gather_with_subgraph(OutputVector& output_values, const auto raw_index = indices->cast_vector()[0]; const auto positive_index = ov::util::normalize(raw_index, rank); OPENVINO_ASSERT(positive_index >= 0 && positive_index < rank); - // gather takes exactly one element out of the Concat output const auto gathered_concat_input = concat_inputs[positive_index].get_source_output().get_node_shared_ptr(); // Concat inputs are 1D, resulting tensor shape depends on Gather indices @@ -77,9 +72,7 @@ bool cf_gather_with_subgraph(OutputVector& output_values, const auto axis_const = v0::Constant::create(element::i64, Shape{1}, {0}); gathered = std::make_shared(gathered_concat_input, axis_const); } - output_values[0] = gathered; - return true; } @@ -262,13 +255,19 @@ bool GatherBase::evaluate_symbol(TensorSymbolVector& output_symbols) const { return gather::have_indices_and_axis_bound_set(this) && ov::util::default_symbol_evaluator(this, output_symbols); } +bool GatherBase::can_constant_fold(const OutputVector& input_values) const { + return get_output_partial_shape(0).is_static() && input_values.size() == 3; +} + bool GatherBase::constant_fold(OutputVector& output_values, const OutputVector& input_values) { // try the regular constant folding just for the Gather node if (Node::constant_fold(output_values, input_values)) { return true; - } else { - return gather::cf_gather_with_subgraph(output_values, input_values, get_output_partial_shape(0)); } + if (!can_constant_fold(input_values)) { + return false; + } + return gather::cf_gather_with_subgraph(output_values, input_values, get_output_partial_shape(0)); } } // namespace util } // namespace op diff --git a/src/core/src/pass/constant_folding.cpp b/src/core/src/pass/constant_folding.cpp index 3de91829f91b0c..cc1a7cea5b5add 100644 --- a/src/core/src/pass/constant_folding.cpp +++ b/src/core/src/pass/constant_folding.cpp @@ -105,6 +105,21 @@ bool ov::pass::ConstantFolding::run_on_model(const std::shared_ptr& m for (const auto& original_node : model->get_ordered_ops()) { auto node = original_node; + if (!original_node->can_constant_fold(original_node->input_values())) { + if (auto sub_graph_node = std::dynamic_pointer_cast(node)) { + // recursively constant fold operators containing subgraphs (ie: TensorIterator, Loop) + size_t sub_graphs_num = sub_graph_node->get_internal_subgraphs_size(); + for (size_t sub_graph_ind = 0; sub_graph_ind < sub_graphs_num; ++sub_graph_ind) { + rewritten = + run_on_model(sub_graph_node->get_function(static_cast(sub_graph_ind))) || rewritten; + } + } + rewritten = restore_original_input_precision(original_node) || rewritten; + if (rewritten) { + original_node->validate_and_infer_types(); + } + continue; + } if (node_has_requires_precision_conversion_attribute(node)) { remove_requires_precision_conversion_attribute(node); node = util::convert_to_supported_precision(node.get()); @@ -143,15 +158,6 @@ bool ov::pass::ConstantFolding::run_on_model(const std::shared_ptr& m } } } else { - if (auto sub_graph_node = std::dynamic_pointer_cast(node)) { - // recursively constant fold operators containing subgraphs (ie: TensorIterator, Loop) - size_t sub_graphs_num = sub_graph_node->get_internal_subgraphs_size(); - for (size_t sub_graph_ind = 0; sub_graph_ind < sub_graphs_num; ++sub_graph_ind) { - rewritten = - run_on_model(sub_graph_node->get_function(static_cast(sub_graph_ind))) || rewritten; - } - } - // if CF was unsuccessful remove original precision attribute from inputs bool restored = restore_original_input_precision(original_node); if (restored) {