diff --git a/src/core/src/validation_util.cpp b/src/core/src/validation_util.cpp
index 6f5ae4e22cd6f0..50140902167164 100644
--- a/src/core/src/validation_util.cpp
+++ b/src/core/src/validation_util.cpp
@@ -151,7 +151,9 @@ namespace util {
using ov::op::v0::Constant;

std::shared_ptr get_constant_from_source(const ov::Output& source) {
-    if (const auto& c = ov::as_type_ptr(source.get_node_shared_ptr())) {
+    if (!source.get_node()) {
+        return {};
+    } else if (const auto& c = ov::as_type_ptr(source.get_node_shared_ptr())) {
        return c;
    } else if (has_and_set_equal_bounds(source)) {
        return std::make_shared(source.get_tensor().get_upper_value());
diff --git a/src/core/tests/validation_utils.cpp b/src/core/tests/validation_utils.cpp
index 1bc5452e7abf23..3fceeca7efdbb8 100644
--- a/src/core/tests/validation_utils.cpp
+++ b/src/core/tests/validation_utils.cpp
@@ -49,6 +49,11 @@ TEST(get_constant_from_source, extract_static_dim_from_dynamic_shape_check) {
    ASSERT_TRUE(extract_static_dimension->get_output_tensor(0).get_upper_value());
}

+TEST(get_constant_from_source, return_nullptr_for_empty_output) {
+    auto res = ov::util::get_constant_from_source(ov::Output());
+    ASSERT_EQ(res, nullptr);
+}
+
TEST(constantfold_subgraph, split) {
    std::vector input{0, 1, 2, 3, 4, 5, 6, 7, 8};
    auto constant = ov::opset8::Constant::create(ov::element::f32, ov::Shape{input.size()}, input);
diff --git a/src/frontends/pytorch/src/frontend.cpp b/src/frontends/pytorch/src/frontend.cpp
index a5baa167db887f..5906043e51262d 100644
--- a/src/frontends/pytorch/src/frontend.cpp
+++ b/src/frontends/pytorch/src/frontend.cpp
@@ -9,7 +9,6 @@
#include "openvino/core/so_extension.hpp"
#include "openvino/frontend/pytorch/extension/conversion.hpp"
#include "openvino/op/util/multi_subgraph_base.hpp"
-#include "openvino/pass/constant_folding.hpp"
#include "openvino/util/common_util.hpp"
#include "openvino/util/log.hpp"
#include "place.hpp"
@@ -19,7 +18,6 @@
#include "transformations/common_optimizations/reverse_shape_and_type_infer.hpp"
#include "transformations/control_flow/unroll_if.hpp"
#include "transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp"
-#include "transformations/low_precision/mark_dequantization_subgraph.hpp"
#include "transformations/op_conversions/convert_convertlike.hpp"
#include "transformations/op_conversions/convert_convertpromotetypes.hpp"
#include "transformations/resolve_names_collisions.hpp"
@@ -247,8 +245,6 @@ std::shared_ptr FrontEnd::decode(const InputModel::Ptr& model) const {
}

void FrontEnd::normalize(const std::shared_ptr& model) const {
-    ov::pass::Manager manager("Frontend:Pytorch:normalize");
-
    bool is_fx = false;
    if (model->has_rt_info("decoder_type_name")) {
        is_fx = model->get_rt_info()["decoder_type_name"].as() == "fx";
@@ -259,29 +255,55 @@ void FrontEnd::normalize(const std::shared_ptr& model) const {
        // GPTQ transformations need to be executed before other passes
        // Once the GPTQ patterns are modified by other transformations,
        // they cannot be captured anymore
+        ov::pass::Manager manager("Frontend:Pytorch:normalize::fx_gptq");
        manager.register_pass();
        manager.register_pass();
+        manager.run_passes(model);
    }
-    // the following 2 transformations are needed for keypoint detectron2 models to work.
-    // AtenIndexToSelect will be called twice
-    manager.register_pass();
-    manager.register_pass();
+    {
+        ov::pass::Manager manager("Frontend:Pytorch:normalize::no_val");
+        // Passes replacing ops without relying on input shapes or types
+        manager.set_per_pass_validation(false);
+        // the following 2 transformations are needed for keypoint detectron2 models to work.
+        // AtenIndexToSelect will be called twice
+        manager.register_pass();
+        manager.register_pass();
-    // Mark quantized and f16/bf16 compressed constants to prevent CF for them,
-    // so that not extra memory is used for intermediate decompressed constants.
-    manager.register_pass(
-        element::TypeVector{element::u8, element::i8, element::u4, element::i4});
-    manager.register_pass();
-    manager.register_pass();
+
+        // Mark quantized and f16/bf16 compressed constants to prevent CF for them,
+        // so that not extra memory is used for intermediate decompressed constants.
+        manager.register_pass();
+
+        manager.register_pass();
+        manager.register_pass();
+        manager.register_pass();
+        manager.register_pass();
+        manager.register_pass();
+        manager.register_pass();
+        manager.register_pass();
+        manager.register_pass();
+        manager.register_pass();
+        manager.register_pass();
+        manager.register_pass();
+        manager.register_pass();
+        manager.register_pass();
+        manager.register_pass();
+        manager.register_pass();
+        manager.register_pass();
+        manager.register_pass();
+        manager.register_pass();
+        manager.register_pass();
+        manager.register_pass();
+        manager.register_pass();
+        bool is_changed = manager.run_passes(model);
+
+        // make validation after previously non-validated passes
+        if (is_changed)
+            model->validate_nodes_and_infer_types();
+    }
-    manager.register_pass();
-    manager.register_pass();
+    ov::pass::Manager manager("Frontend:Pytorch:normalize");
    manager.register_pass();
-    manager.register_pass();
-    manager.register_pass();
-    manager.register_pass();
-    manager.register_pass();
    manager.register_pass();
    manager.register_pass();
    manager.register_pass();
@@ -289,19 +311,7 @@ void FrontEnd::normalize(const std::shared_ptr& model) const {
    manager.register_pass();
    manager.register_pass();
    manager.register_pass();
-    manager.register_pass();
-    manager.register_pass();
-    manager.register_pass();
-    manager.register_pass();
-    manager.register_pass();
-    manager.register_pass();
-    manager.register_pass();
-    manager.register_pass();
-    manager.register_pass();
-    manager.register_pass();
    manager.register_pass();
-    manager.register_pass();
-    manager.register_pass();

    // Check if model is symmetrically quantized
    bool sym = false;
@@ -311,15 +321,25 @@ void FrontEnd::normalize(const std::shared_ptr& model) const {
    }
    manager.register_pass(sym);
-    manager.register_pass();
-    manager.register_pass();
-    manager.register_pass();
    manager.register_pass();
-    manager.register_pass();
-    manager.register_pass(true);
-    manager.register_pass();
    manager.run_passes(model);

+    {
+        ov::pass::Manager manager("Frontend:Pytorch:normalize::followup_no_val");
+        manager.set_per_pass_validation(false);
+
+        // ReverseShapeAndTypeInfer needs to run on validated model, it relies on shapes and types
+        manager.register_pass();
+        // ConvertConvertLike will benefit from types inserted by ReverseShapeAndTypeInfer
+        manager.register_pass();
+        manager.register_pass(true);
+        bool is_changed = manager.run_passes(model);
+
+        // make validation after previously non-validated passes
+        if (is_changed)
+            model->validate_nodes_and_infer_types();
+    }
+
    // Usually if nn.Module.forward is given as a source model for conversion, there is the first Parameter
    // that represents original `self` argument in forward(self, ...). `self` shouldn't play any role in model
    // inference if model is completely frozen and all methods are inlined. So we check if it doesn't have any
diff --git a/src/frontends/pytorch/src/op/expand.cpp b/src/frontends/pytorch/src/op/expand.cpp
index b4ac055336daf9..a6bc239df96562 100644
--- a/src/frontends/pytorch/src/op/expand.cpp
+++ b/src/frontends/pytorch/src/op/expand.cpp
@@ -28,7 +28,7 @@ OutputVector translate_expand(const NodeContext& context) {
    // aten::expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)
    num_inputs_check(context, 2, 3);
    auto x = context.get_input(0);
-    auto sizes = context.get_input(1);
+    auto sizes = get_input_concat_if_list(context, 1);
    // TODO: figure out what implicit means
    PYTORCH_OP_CONVERSION_CHECK(context.input_is_none(2) || context.const_input(2) == false,
                                "Unexpected value of implicit for expand operation");
diff --git a/src/frontends/pytorch/src/op/full.cpp b/src/frontends/pytorch/src/op/full.cpp
index ad0eb04527aa12..799c5d6feaebbe 100644
--- a/src/frontends/pytorch/src/op/full.cpp
+++ b/src/frontends/pytorch/src/op/full.cpp
@@ -57,7 +57,7 @@ Output base_translate_full_with_convert(const NodeContext& context,

OutputVector translate_full(const NodeContext& context) {
    num_inputs_check(context, 2, 6);
-    auto sizes = context.get_input(0);
+    auto sizes = get_input_concat_if_list(context, 0);
    auto value = context.get_input(1);
    auto num_inputs = context.get_input_size();
    if (num_inputs < 6) {
@@ -136,7 +136,7 @@ OutputVector translate_fill(const NodeContext& context) {

OutputVector translate_new_full(const NodeContext& context) {
    num_inputs_check(context, 3, 7);
    auto input = context.get_input(0);
-    auto sizes = context.get_input(1);
+    auto sizes = get_input_concat_if_list(context, 1);
    auto value = context.get_input(2);
    if (context.get_input_size() == 7 && !context.input_is_none(3)) {
        return {base_translate_full_with_convert(context, sizes, value, 3)};
@@ -161,7 +161,7 @@ OutputVector translate_new_full_fx(const NodeContext& context) {

OutputVector translate_zeros(const NodeContext& context) {
    num_inputs_check(context, 2, 5);
-    auto sizes = context.get_input(0);
+    auto sizes = get_input_concat_if_list(context, 0);
    auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
    auto num_inputs = context.get_input_size();
    if (num_inputs < 5) {
@@ -218,7 +218,7 @@ OutputVector translate_zeros_like_fx(const NodeContext& context) {

OutputVector translate_new_zeros(const NodeContext& context) {
    num_inputs_check(context, 2, 6);
    auto input = context.get_input(0);
-    auto sizes = context.get_input(1);
+    auto sizes = get_input_concat_if_list(context, 1);
    auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
    if (context.get_input_size() == 6 && !context.input_is_none(2)) {
        return {base_translate_full_with_convert(context, sizes, value, 2)};
@@ -243,7 +243,7 @@ OutputVector translate_new_zeros_fx(const NodeContext& context) {

OutputVector translate_ones(const NodeContext& context) {
    num_inputs_check(context, 1, 5);
-    auto sizes = context.get_input(0);
+    auto sizes = get_input_concat_if_list(context, 0);
    auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1}));
    auto num_inputs = context.get_input_size();
    if (num_inputs < 5) {
@@ -300,7 +300,7 @@ OutputVector translate_ones_like_fx(const NodeContext& context) {

OutputVector translate_new_ones(const NodeContext& context) {
    num_inputs_check(context, 2, 6);
    auto input = context.get_input(0);
-    auto sizes = context.get_input(1);
+    auto sizes = get_input_concat_if_list(context, 1);
    auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1}));
    if (context.get_input_size() == 6 && !context.input_is_none(2)) {
        return {base_translate_full_with_convert(context, sizes, value, 2)};
@@ -328,7 +328,7 @@ OutputVector translate_empty(const NodeContext& context) {
    // pin_memory=None, MemoryFormat? memory_format=None) -> Tensor layout, device and work with memory ignored on our
    // side, so just skip these parameters
    num_inputs_check(context, 1, 6);
-    auto sizes = context.get_input(0);
+    auto sizes = get_input_concat_if_list(context, 0);
    // In OV uninitialized data is not supported, so we create a tensor filled with zeros with a given shape and type.
    auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
    int dtype_id = 1;
diff --git a/src/frontends/pytorch/src/op/reshape.cpp b/src/frontends/pytorch/src/op/reshape.cpp
index 1a6b3008883c0c..a2c1a43a4fcb53 100644
--- a/src/frontends/pytorch/src/op/reshape.cpp
+++ b/src/frontends/pytorch/src/op/reshape.cpp
@@ -21,7 +21,8 @@ OutputVector translate_reshape(const NodeContext& context) {
    // Schema: aten::reshape(Tensor input, int[] shape) -> Tensor
    // For shape parameter, int[] is converted into single dimensional Tensor.
    num_inputs_check(context, 2, 2);
-    auto reshape = std::make_shared(context.get_input(0), context.get_input(1), false);
+    auto shape = get_input_concat_if_list(context, 1);
+    auto reshape = std::make_shared(context.get_input(0), shape, false);
    return {context.mark_node(reshape)};
};
diff --git a/src/frontends/pytorch/src/op/roll.cpp b/src/frontends/pytorch/src/op/roll.cpp
index 3d037d22887778..5e3953ddf881f3 100644
--- a/src/frontends/pytorch/src/op/roll.cpp
+++ b/src/frontends/pytorch/src/op/roll.cpp
@@ -20,7 +20,7 @@ using namespace ov::op;

OutputVector translate_roll(const NodeContext& context) {
    num_inputs_check(context, 2, 3);
    const auto data = context.get_input(0);
-    const auto shifts = context.get_input(1);
+    const auto shifts = get_input_concat_if_list(context, 1);
    Output axes;
    bool on_flattened = context.input_is_none(2);
    if (!on_flattened) {
diff --git a/src/frontends/pytorch/src/op/slice.cpp b/src/frontends/pytorch/src/op/slice.cpp
index f0d48a8075c901..6738f243c71ea5 100644
--- a/src/frontends/pytorch/src/op/slice.cpp
+++ b/src/frontends/pytorch/src/op/slice.cpp
@@ -6,6 +6,7 @@

#include

+#include "openvino/core/validation_util.hpp"
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/reshape.hpp"
@@ -51,6 +52,9 @@ OutputVector translate_slice_common(const NodeContext& context,
        if (start.get_partial_shape().rank().is_dynamic() || start.get_partial_shape().rank().get_length() == 0) {
            start = context.mark_node(std::make_shared(start, dims_1d_shape, false));
        }
+        if (const auto start_const = ov::util::get_constant_from_source(start)) {
+            start = start_const;
+        }
    } else {
        start = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
    }
@@ -63,6 +67,9 @@ OutputVector translate_slice_common(const NodeContext& context,
            (!(end.get_partial_shape().rank().is_dynamic()) && end.get_partial_shape().rank().get_length() == 0)) {
            end = context.mark_node(std::make_shared(end, dims_1d_shape, false));
        }
+        if (const auto end_const = ov::util::get_constant_from_source(end)) {
+            end = end_const;
+        }
    } else {
        end = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {INT_MAX}));
    }
@@ -72,6 +79,9 @@ OutputVector translate_slice_common(const NodeContext& context,
        if (step.get_partial_shape().rank().is_dynamic() || step.get_partial_shape().rank().get_length() == 0) {
            step = context.mark_node(std::make_shared(step, dims_1d_shape, false));
        }
+        if (const auto step_const = ov::util::get_constant_from_source(step)) {
+            step = step_const;
+        }
    } else {
        step = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
    }
diff --git a/src/frontends/pytorch/src/op/transpose.cpp b/src/frontends/pytorch/src/op/transpose.cpp
index 4ffb668d59821c..b3c87a994f7d97 100644
--- a/src/frontends/pytorch/src/op/transpose.cpp
+++ b/src/frontends/pytorch/src/op/transpose.cpp
@@ -4,8 +4,8 @@

#include "openvino/op/transpose.hpp"

+#include "openvino/core/validation_util.hpp"
#include "openvino/frontend/pytorch/node_context.hpp"
-#include "openvino/op/add.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
@@ -29,8 +29,9 @@ using namespace ov::op;

OutputVector translate_transpose(const NodeContext& context) {
    num_inputs_check(context, 3, 3);
+    auto data = context.get_input(0);
    Output rank;
-    std::tie(std::ignore, rank) = get_shape_rank(context, context.get_input(0), true);
+    std::tie(std::ignore, rank) = get_shape_rank(context, data, true);
    auto dim0_node = get_input_as_i32(context, 1);
    auto dim1_node = get_input_as_i32(context, 2);
    dim0_node = normalize_axis(context, dim0_node, rank);
@@ -44,10 +45,15 @@ OutputVector translate_transpose(const NodeContext& context) {
    auto dim1_node_ = std::make_shared(dim1_node, axis_0);
    auto indices = std::make_shared(OutputVector{dim0_node_, dim1_node_}, 0);
    auto updates = std::make_shared(OutputVector{dim1_node_, dim0_node_}, 0);
-    auto scatter = std::make_shared(range, indices, updates, axis_0);
-    context.mark_nodes({start, step, range, axis_0, dim0_node_, dim1_node_, indices, updates, scatter});
+    Output scatter = std::make_shared(range, indices, updates, axis_0);
+    if (const auto scatter_const = ov::util::get_constant_from_source(scatter)) {
+        scatter = context.mark_node(scatter_const);
+    } else {
+        context.mark_nodes(
+            {start, step, range, axis_0, dim0_node_, dim1_node_, indices, updates, scatter.get_node_shared_ptr()});
+    }

-    return {context.mark_node(std::make_shared(context.get_input(0), scatter))};
+    return {context.mark_node(std::make_shared(data, scatter))};
};

OutputVector translate_t(const NodeContext& context) {
diff --git a/src/frontends/pytorch/src/transforms/append_list_unpack_replacer.cpp b/src/frontends/pytorch/src/transforms/append_list_unpack_replacer.cpp
index 6263e769639e6b..6e48c69126fd03 100644
--- a/src/frontends/pytorch/src/transforms/append_list_unpack_replacer.cpp
+++ b/src/frontends/pytorch/src/transforms/append_list_unpack_replacer.cpp
@@ -8,6 +8,7 @@
#include

#include "openvino/core/rt_info.hpp"
+#include "openvino/core/validation_util.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/split.hpp"
#include "openvino/op/squeeze.hpp"
@@ -64,8 +65,9 @@ AppendListUnpackReplacer::AppendListUnpackReplacer() {
        if (getitem_node) {
            // If aten::__getitem__, expect inputs to be equivalent of pytorch Tensor[][].
            // Tensor selected by aten::__getitem__ index needs to be splitted in axis 0.
-            auto getitem_index_ptr = getitem_node->input_value(1).get_node_shared_ptr();
-            auto getitem_index_const = std::dynamic_pointer_cast(getitem_index_ptr);
+            auto getitem_index_const = ov::util::get_constant_from_source(getitem_node->input_value(1));
+            if (!getitem_index_const)
+                return false;
            auto index_val = getitem_index_const->cast_vector();
            if (index_val.size() != 1) {
                add_exception_to_fw_node(list_unpack, "prim::ListUnpack: index of aten::__getitem__ is not scalar.");
diff --git a/src/frontends/pytorch/src/transforms/aten_cat_replacer.cpp b/src/frontends/pytorch/src/transforms/aten_cat_replacer.cpp
index 9cc20994d2cd41..1f31c75e6ae6c8 100644
--- a/src/frontends/pytorch/src/transforms/aten_cat_replacer.cpp
+++ b/src/frontends/pytorch/src/transforms/aten_cat_replacer.cpp
@@ -8,6 +8,7 @@
#include

#include "openvino/core/rt_info.hpp"
+#include "openvino/core/validation_util.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/loop.hpp"
@@ -46,8 +47,7 @@ AtenCatToConcat::AtenCatToConcat() {

        int64_t axis;
        if (cat->get_input_size() > 1) {
-            auto axis_node = cat->get_input_node_shared_ptr(1);
-            auto axis_const = std::dynamic_pointer_cast(axis_node);
+            auto axis_const = ov::util::get_constant_from_source(cat->input_value(1));
            if (!axis_const) {
                add_exception_to_fw_node(cat, "::cat unsupported case: axis is not a constant.");
                return false;
diff --git a/src/frontends/pytorch/src/transforms/aten_getitem_replacer.cpp b/src/frontends/pytorch/src/transforms/aten_getitem_replacer.cpp
index 5d371e0d65f725..25be27e9939204 100644
--- a/src/frontends/pytorch/src/transforms/aten_getitem_replacer.cpp
+++ b/src/frontends/pytorch/src/transforms/aten_getitem_replacer.cpp
@@ -8,6 +8,7 @@
#include

#include "openvino/core/rt_info.hpp"
+#include "openvino/core/validation_util.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/ceiling.hpp"
#include "openvino/op/constant.hpp"
@@ -94,8 +95,9 @@ AtenGetItemReplacer::AtenGetItemReplacer() {
                auto split = rg.make(input, split_slice_start, split_slice_end, const_1, axis_1d);
                replace_node(getitem, split);
            } else {
-                auto getitem_index_ptr = getitem->input_value(1).get_node_shared_ptr();
-                auto getitem_index_const = std::dynamic_pointer_cast(getitem_index_ptr);
+                auto getitem_index_const = ov::util::get_constant_from_source(getitem->input_value(1));
+                if (!getitem_index_const)
+                    return false;
                auto split = rg.make(torch_split->get_input_source_output(0),
                                     torch_split->get_input_source_output(2),
                                     torch_split->get_input_source_output(1));
@@ -111,8 +113,8 @@ AtenGetItemReplacer::AtenGetItemReplacer() {
                    getitem->output(0).replace(split->outputs()[index]);
                }
            } else if (auto list_construct = cast_fw_node(input_node, "prim::ListConstruct")) {
-                auto getitem_idx = getitem->input_value(1).get_node_shared_ptr();
-                auto getitem_idx_const = std::dynamic_pointer_cast(getitem_idx);
+                auto getitem_idx = getitem->input_value(1);
+                auto getitem_idx_const = ov::util::get_constant_from_source(getitem_idx);
                if (getitem_idx_const) {
                    auto idx = getitem_idx_const->cast_vector();
                    getitem->output(0).replace(list_construct->input_value(idx[0]));
diff --git a/src/frontends/pytorch/src/transforms/aten_index_put_replacer.cpp b/src/frontends/pytorch/src/transforms/aten_index_put_replacer.cpp
index 4a658b34686673..587ff587bd333e 100644
--- a/src/frontends/pytorch/src/transforms/aten_index_put_replacer.cpp
+++ b/src/frontends/pytorch/src/transforms/aten_index_put_replacer.cpp
@@ -5,6 +5,7 @@
#include "aten_index_put_replacer.hpp"

#include "openvino/core/rt_info.hpp"
+#include "openvino/core/validation_util.hpp"
#include "openvino/frontend/pytorch/visibility.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/broadcast.hpp"
@@ -65,7 +66,7 @@ AtenIndexPutReplacer::AtenIndexPutReplacer() {
        auto input_shape = rg.make(input, element::i32);
        auto indices = index_op->input_value(1);
        auto values = index_op->input_value(2);
-        auto acc_const = std::dynamic_pointer_cast(index_op->input_value(3).get_node_shared_ptr());
+        auto acc_const = ov::util::get_constant_from_source(index_op->input_value(3));
        if (!acc_const) {
            add_exception_to_fw_node(index_op, "aten::index_put_: non constant accumulate input is not supported.");
            return false;
diff --git a/src/frontends/pytorch/src/transforms/index_loop_getitem_replacer.cpp b/src/frontends/pytorch/src/transforms/index_loop_getitem_replacer.cpp
index f174cd623c71ee..4aae4d6e2a35dc 100644
--- a/src/frontends/pytorch/src/transforms/index_loop_getitem_replacer.cpp
+++ b/src/frontends/pytorch/src/transforms/index_loop_getitem_replacer.cpp
@@ -5,6 +5,7 @@
#include "index_loop_getitem_replacer.hpp"

#include "openvino/core/rt_info.hpp"
+#include "openvino/core/validation_util.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
@@ -85,8 +86,8 @@ IndexLoopGetitemReplacer::IndexLoopGetitemReplacer() {
            return false;
        }

-        auto dim = chunk_op->input_value(2);
-        if (!ov::as_type_ptr(dim.get_node_shared_ptr())) {
+        auto dim = ov::util::get_constant_from_source(chunk_op->input_value(2));
+        if (!dim) {
            add_exception_to_fw_node(chunk_op, "aten::chunk: dimension is not constant.");
            return false;
        }
@@ -120,7 +121,7 @@ IndexLoopGetitemReplacer::IndexLoopGetitemReplacer() {
        // Add new inputs in Loop: chunk_size and dim_1d
        auto inp_descs = loop_op->get_input_descriptions();
        auto chunks_size_body = rg.make(element::i32, Shape{1});
-        auto dim_body = rg.make(dim.get_element_type(), Shape{1});
+        auto dim_body = rg.make(dim->get_element_type(), Shape{1});
        body->add_parameters({chunks_size_body, dim_body});
        loop_op->set_argument(loop_op->get_input_size(), chunk_size);
        loop_op->set_argument(loop_op->get_input_size(), dim_1d);
diff --git a/src/frontends/pytorch/src/transforms/listconstruct_replacer.cpp b/src/frontends/pytorch/src/transforms/listconstruct_replacer.cpp
index caa8d3cc18a3e6..b88cd48a1790a0 100644
--- a/src/frontends/pytorch/src/transforms/listconstruct_replacer.cpp
+++ b/src/frontends/pytorch/src/transforms/listconstruct_replacer.cpp
@@ -5,6 +5,7 @@
#include "listconstruct_replacer.hpp"

#include "openvino/core/rt_info.hpp"
+#include "openvino/core/validation_util.hpp"
#include "openvino/op/abs.hpp"
#include "openvino/op/adaptive_avg_pool.hpp"
#include "openvino/op/broadcast.hpp"
@@ -40,21 +41,11 @@ ListConstructReplacer::ListConstructReplacer() {
    // Transformation for torch operators for cases where prim::ListConstruct can be replaced with Concat.
    const auto& list = pattern::wrap_type();
-    // Both aten::view and aten::reshape are using same translation returning Reshape operator.
-    const auto& reshape_op = pattern::wrap_type({pattern::any_input(), list});
-    const auto& roll_op = pattern::wrap_type({pattern::any_input(), list, pattern::any_input()});
    const auto& broadcast_op = pattern::wrap_type({pattern::any_input(), list});
-    const auto& adapool_op = pattern::wrap_type({pattern::any_input(), list});
-    // replace list construct for aten::expand(tensor, prim::ListConstruct(shapes)) old decomposition
-    // shape_of + broadcast + equal + select
    const auto& shape_of_op = pattern::wrap_type({list});
    const auto& equal_op = pattern::wrap_type({list, pattern::any_input()});
    const auto& select_op = pattern::wrap_type({pattern::any_input(), pattern::any_input(), list});
-    // replace list construct for aten::expand(tensor, prim::ListConstruct(shapes)) new decomposition
-    const auto& abs_op = pattern::wrap_type({list});
-    const auto& expand_op = pattern::wrap_type({pattern::any_input(), abs_op});
    // replace list construct for aten::repeat(tensor, prim::ListConstruct(shapes)))
-    // shape_of + broadcast + tile
    const auto& tile_op = pattern::wrap_type({pattern::any_input(), list});
    // replace aten::permute(tensor, prim::ListConstruct)
    const auto& transpose_op = pattern::wrap_type({pattern::any_input(), list});
@@ -67,14 +58,10 @@ ListConstructReplacer::ListConstructReplacer() {
        pattern::wrap_type({pattern::any_input(), interpolate_mul_op, pattern::any_input()});
    // aten::randint case
    const auto& rand_op = pattern::wrap_type({list, pattern::any_input(), pattern::any_input()});
-    const auto& lc_pattern = std::make_shared(OutputVector{reshape_op,
-                                                           roll_op,
-                                                           broadcast_op,
-                                                           adapool_op,
+    const auto& lc_pattern = std::make_shared(OutputVector{broadcast_op,
                                                           shape_of_op,
                                                           equal_op,
                                                           select_op,
-                                                           expand_op,
                                                           tile_op,
                                                           transpose_op,
                                                           vsplit_op,
@@ -84,13 +71,13 @@ ListConstructReplacer::ListConstructReplacer() {
    ov::matcher_pass_callback callback = [=](pattern::Matcher& m) {
        auto& pattern_map = m.get_pattern_value_map();
-        auto list_node = pattern_map.at(list).get_node_shared_ptr();
+        auto list_out = pattern_map.at(list);
        // Concatenation is possible because all elements in list should be scalar or 1D tensors,
        // result should be 1D tensor.
        OutputVector inputs;
        ov::pass::NodeRegistry rg;
        auto neg_1 = v0::Constant::create(element::i32, Shape{1}, {-1});
-        const auto& start_output = list_node->output(0);
+        const auto& start_output = list_out;
        for (const auto& input : get_list_as_outputs(start_output)) {
            if (input == start_output) {
                // Start output exist in list elements, it might mean we have only 1 element in list inputs and it is
@@ -105,11 +92,15 @@ ListConstructReplacer::ListConstructReplacer() {
            }
            // reshape all elements to 1D
            auto reshape = rg.make(input, neg_1, false);
-            inputs.push_back(reshape);
+            if (const auto list_const = ov::util::get_constant_from_source(reshape)) {
+                inputs.push_back(list_const);
+            } else {
+                inputs.push_back(reshape);
+            }
        }
        auto concat = rg.make(inputs, 0);
-        copy_runtime_info_and_name(list_node, rg.get());
-        replace_node(list_node, concat);
+        copy_runtime_info_and_name(list_out.get_node_shared_ptr(), rg.get());
+        replace_node(list_out.get_node_shared_ptr(), concat);
        return true;
    };
    auto m = std::make_shared(lc_pattern, "ov::frontend::pytorch::pass::ListConstructReplacer");
diff --git a/src/frontends/pytorch/src/transforms/u4_block_repack.cpp b/src/frontends/pytorch/src/transforms/u4_block_repack.cpp
index 9fc4d2c8b104bb..675a293269002b 100644
--- a/src/frontends/pytorch/src/transforms/u4_block_repack.cpp
+++ b/src/frontends/pytorch/src/transforms/u4_block_repack.cpp
@@ -22,10 +22,10 @@
using namespace ov::op;
using namespace ov::pass::pattern;

U4BlockRepack::U4BlockRepack(bool is_symmetrical) {
-    const auto& m_constant = ov::pass::pattern::wrap_type();
-    const auto& m_reshape1 = ov::pass::pattern::wrap_type({m_constant, any_input()});
-    const auto& m_transpose = ov::pass::pattern::wrap_type({m_reshape1, any_input()});
-    const auto& m_reshape2 = ov::pass::pattern::wrap_type({m_transpose, any_input()});
+    const auto& m_constant = wrap_type();
+    const auto& m_reshape1 = wrap_type({m_constant, any_input()});
+    const auto& m_transpose = wrap_type({m_reshape1, any_input()});
+    const auto& m_reshape2 = wrap_type({m_transpose, any_input()});

    auto pack_byte = [](uint8_t lo, uint8_t hi) -> uint8_t {
        return (hi << 4) | (lo & 0x0F);
    };

    register_matcher(
-        std::make_shared(m_reshape2, "ov::frontend::pytorch::pass::U4BlockRepack"),
-        [=](ov::pass::pattern::Matcher& m) {
+        std::make_shared(m_reshape2, "ov::frontend::pytorch::pass::U4BlockRepack"),
+        [=](Matcher& m) {
            auto& pattern_to_output = m.get_pattern_value_map();
            auto constant =
                std::dynamic_pointer_cast(pattern_to_output[m_constant].get_node_shared_ptr());
diff --git a/src/frontends/pytorch/src/utils.cpp b/src/frontends/pytorch/src/utils.cpp
index f9a24ee739cff8..11e62baf2b606b 100644
--- a/src/frontends/pytorch/src/utils.cpp
+++ b/src/frontends/pytorch/src/utils.cpp
@@ -6,6 +6,7 @@
#include "op_table.hpp"
#include "openvino/core/rt_info.hpp"
+#include "openvino/core/validation_util.hpp"
#include "openvino/frontend/pytorch/decoder.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/broadcast.hpp"
@@ -149,8 +150,15 @@ std::shared_ptr get_node_axes_range(const NodeContext& context, const Outp
};

Output normalize_axis(const NodeContext& context, const Output& axis, const Output& rank) {
-    auto axis_rank = context.mark_node(std::make_shared(axis, rank));
-    return context.mark_node(std::make_shared(axis_rank, rank));
+    auto axis_rank = std::make_shared(axis, rank);
+    auto new_axis = std::make_shared(axis_rank, rank);
+
+    if (const auto axis_const = ov::util::get_constant_from_source(new_axis)) {
+        return context.mark_node(axis_const);
+    } else {
+        context.mark_nodes({axis_rank, new_axis});
+        return new_axis;
+    }
}

std::shared_ptr numel(const NodeContext& context, const Output& x, element::Type output_type) {
@@ -176,8 +184,8 @@ const std::unordered_map TORCH_TO_OV_TYPE{
    {15, element::bf16},
};

-const std::unordered_map TORCH_AUTO_PAD_TO_OV{{"valid", ov::op::PadType::VALID},
-                                              {"same", ov::op::PadType::SAME_UPPER}};
+const std::unordered_map TORCH_AUTO_PAD_TO_OV{{"valid", PadType::VALID},
+                                              {"same", PadType::SAME_UPPER}};
}  // namespace

element::Type convert_dtype(int64_t pt_type) {
@@ -200,7 +208,7 @@ Output apply_dtype(const NodeContext& context, size_t dtype_port, const Ou
    return input_tensor;
};

-ov::op::PadType convert_pad(const std::string& pt_pad) {
+PadType convert_pad(const std::string& pt_pad) {
    FRONT_END_OP_CONVERSION_CHECK(TORCH_AUTO_PAD_TO_OV.count(pt_pad), "Unknown pad: ", pt_pad);
    return TORCH_AUTO_PAD_TO_OV.at(pt_pad);
};
@@ -371,7 +379,7 @@ std::shared_ptr cast_fw_node(std::shared_ptr
}

std::shared_ptr make_list_construct(const ov::OutputVector& inputs) {
-    auto list_construct = std::make_shared<::ov::op::util::FrameworkNode>(inputs, inputs.size());
+    auto list_construct = std::make_shared(inputs, inputs.size());
    ov::op::util::FrameworkNodeAttrs attrs;
    attrs.set_type_name("PTFrameworkNode");
    attrs[PtFrameworkNode::op_type_key] = "prim::ListConstruct";
@@ -420,8 +428,8 @@ void align_eltwise_input_types(const NodeContext& context,
                               const bool& is_rhs_python_scalar) {
    const auto& lhs_type = lhs.get_element_type();
    const auto& rhs_type = rhs.get_element_type();
-    auto const_0 = ov::op::v0::Constant::create(element::i32, Shape{}, {1});
-    auto const_1 = ov::op::v0::Constant::create(element::i32, Shape{1}, {1});
+    auto const_0 = v0::Constant::create(element::i32, Shape{}, {1});
+    auto const_1 = v0::Constant::create(element::i32, Shape{1}, {1});
    // Create temporary copy of lhs and rhs for ConvertPromoteTypes to not modify original nodes.
    ov::Output tmp_lhs = lhs;
    ov::Output tmp_rhs = rhs;
@@ -436,8 +444,7 @@ void align_eltwise_input_types(const NodeContext& context,
        tmp_rhs = context.mark_node(std::make_shared(const_0, rhs));
    }

-    auto at = context.mark_node(
-        std::make_shared(tmp_lhs, tmp_rhs, true, true, element::f32));
+    auto at = context.mark_node(std::make_shared(tmp_lhs, tmp_rhs, true, true, element::f32));
    auto dst_type = at->get_output_element_type(0);
    if (dst_type.is_dynamic()) {
        // Add ConvertLike on original node to not remove changes to shape done to differentiate between tensors and
@@ -468,10 +475,18 @@ void align_output_types(const NodeContext& context, OutputVector& outputs) {
    }
}

+Output try_constfold(const Output& x) {
+    auto res = x;
+    if (const auto x_const = ov::util::get_constant_from_source(x)) {
+        res = x_const;
+    }
+    return res;
+}
+
Output get_input_with_floating_type(const NodeContext& context, size_t idx) {
    auto x = context.get_input(static_cast(idx));
    // This const only needed for type alignment
-    auto dummy_const = context.mark_node(ov::op::v0::Constant::create(element::f32, Shape({}), {0.5}))->output(0);
+    auto dummy_const = context.mark_node(v0::Constant::create(element::f32, Shape({}), {0.5}))->output(0);
    align_eltwise_input_types(context, x, dummy_const, false, true);
    return x;
}
@@ -479,7 +494,29 @@ Output get_input_as_i32(const NodeContext& context, size_t idx) {
    auto x = context.get_input(static_cast(idx));
    if (x.get_element_type() != element::i32) {
-        x = context.mark_node(std::make_shared(x, element::i32));
+        x = context.mark_node(std::make_shared(x, element::i32));
+    }
+    return x;
+}
+
+Output get_input_concat_if_list(const NodeContext& context, size_t idx) {
+    auto x = context.get_input(static_cast(idx));
+    if (context.get_input_type(idx).is() &&
+        std::dynamic_pointer_cast(x.get_node_shared_ptr())) {
+        auto elems = get_list_as_outputs(x, true);
+        if (elems.size() == 0)
+            // Can we figure real type for empty list?
+            return std::make_shared(element::i32, Shape{0}, std::vector{});
+        OutputVector inputs;
+        for (auto& elem : elems) {
+            inputs.push_back(try_constfold(elem));
+        }
+        auto new_x = std::make_shared(inputs, 0);
+        new_x->set_friendly_name(x.get_node_shared_ptr()->get_friendly_name());
+        x = new_x;
+    }
+    if (const auto x_const = ov::util::get_constant_from_source(x)) {
+        return x_const;
+    }
    return x;
}
@@ -499,9 +536,10 @@ std::tuple, Output> get_inputs_with_promoted_types(const Node
    return std::make_tuple(lhs, rhs);
}

-std::deque> get_list_as_outputs(const Output& start) {
+std::deque> get_list_as_outputs(const Output& start, bool unsqueeze_for_concat) {
    std::deque> res;
    auto current_output = start;
+    auto zero = v0::Constant::create(element::i32, Shape{}, {0});
    while (const auto& input_fw_node =
               std::dynamic_pointer_cast(current_output.get_node_shared_ptr())) {
        const auto& attrs = input_fw_node->get_attrs();
@@ -509,20 +547,28 @@ std::deque> get_list_as_outputs(const Output& start) {
            break;
        }
        if (attrs.at(PtFrameworkNode::op_type_key) == "aten::append") {
-            res.push_front(input_fw_node->input(1).get_source_output());
+            auto elem = input_fw_node->get_input_source_output(1);
+            if (unsqueeze_for_concat) {
+                elem = std::make_shared(elem, zero);
+            }
+            res.push_front(elem);
        } else if (attrs.at(PtFrameworkNode::op_type_key) == "aten::add") {
-            const auto&& lhs_list = get_list_as_outputs(input_fw_node->input(1).get_source_output());
-            res.insert(res.end(), lhs_list.begin(), lhs_list.end());
+            const auto&& rhs_list = get_list_as_outputs(input_fw_node->get_input_source_output(1));
+            res.insert(res.end(), rhs_list.begin(), rhs_list.end());
        } else {
            break;
        }
-        current_output = input_fw_node->input(0).get_source_output();
+        current_output = input_fw_node->get_input_source_output(0);
    }
    auto list_construct = cast_fw_node(current_output.get_node_shared_ptr(), "prim::ListConstruct");
    if (list_construct) {
        auto inputs = list_construct->inputs();
        for (auto input_it = inputs.rbegin(); input_it != inputs.rend(); ++input_it) {
-            res.push_front(input_it->get_source_output());
+            auto elem = input_it->get_source_output();
+            if (unsqueeze_for_concat) {
+                elem = std::make_shared(elem, zero);
+            }
+            res.push_front(elem);
        }
    } else {
        res.push_front(current_output);
@@ -579,20 +625,20 @@ Output concat_list_from_inputs(const NodeContext& context, size_t begin, s
            auto const_val = context.const_input(i);
            std::vector dim_vec;
            dim_vec.push_back(const_val);
-            auto dim_const = ov::op::v0::Constant::create(element::i64, Shape{1}, dim_vec);
+            auto dim_const = v0::Constant::create(element::i64, Shape{1}, dim_vec);
            list_elems.push_back(dim_const);
        } else {
            auto input_dim = context.get_input(static_cast(i));
            if (input_dim.get_partial_shape().rank() == 0) {
-                auto zero = ov::op::v0::Constant::create(element::i32, Shape{}, {0});
-                auto unsqueezed_dim = context.mark_node(std::make_shared(input_dim, zero));
+                auto zero = v0::Constant::create(element::i32, Shape{}, {0});
+                auto unsqueezed_dim = context.mark_node(std::make_shared(input_dim, zero));
                list_elems.push_back(unsqueezed_dim);
            } else {
                list_elems.push_back(input_dim);
            }
        }
    }
-    auto concat = std::make_shared(list_elems, 0);
+    auto concat = std::make_shared(list_elems, 0);
    return concat;
}
diff --git a/src/frontends/pytorch/src/utils.hpp b/src/frontends/pytorch/src/utils.hpp
index 3d99573098a86c..434cc109d022aa 100644
--- a/src/frontends/pytorch/src/utils.hpp
+++ b/src/frontends/pytorch/src/utils.hpp
@@ -99,16 +99,20 @@ void align_eltwise_input_types(const NodeContext& context,
                               const bool& ir_rhs_python_scalar = false);

void align_output_types(const NodeContext& context, OutputVector& outputs);

-std::deque> get_list_as_outputs(const Output& start);
+std::deque> get_list_as_outputs(const Output& start, bool unsqueeze_for_concat = false);

void copy_runtime_info_and_name(const std::shared_ptr& from,
                                ov::NodeVector to,
                                const ov::NodeVector& additional_rt_info_src = {});

+Output try_constfold(const Output& x);
+
Output get_input_with_floating_type(const NodeContext& context, size_t idx);

Output get_input_as_i32(const NodeContext& context, size_t idx);

+Output get_input_concat_if_list(const NodeContext& context, size_t idx);
+
std::tuple, Output> get_inputs_with_promoted_types(const NodeContext& context,
                                                   size_t lhs_idx,
                                                   size_t rhs_idx);
diff --git a/src/frontends/pytorch/src/utils_quantize.cpp b/src/frontends/pytorch/src/utils_quantize.cpp
index ccf7f20785b09b..e48c61314f4c0d 100644
--- a/src/frontends/pytorch/src/utils_quantize.cpp
+++ b/src/frontends/pytorch/src/utils_quantize.cpp
@@ -4,6 +4,7 @@
#include "utils_quantize.hpp"

+#include "openvino/core/validation_util.hpp"
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/bitwise_and.hpp"
#include "openvino/op/broadcast.hpp"
@@ -44,8 +45,10 @@ Output quantize_common(const NodeContext& context,
        const auto out_high_normalized = context.mark_node(std::make_shared(out_high, zero_point_convert));

-        const auto bound_low = context.mark_node(std::make_shared(scale_convert, out_low_normalized));
-        const auto bound_high = context.mark_node(std::make_shared(scale_convert, out_high_normalized));
+        auto bound_low =
+            try_constfold(context.mark_node(std::make_shared(scale_convert, out_low_normalized)));
+        auto bound_high =
+            try_constfold(context.mark_node(std::make_shared(scale_convert, out_high_normalized)));

        const auto quantized_input = context.mark_node(
            std::make_shared(input_convert, bound_low, bound_high, bound_low, bound_high, levels));
@@ -60,7 +63,7 @@ Output quantize_common(const NodeContext& context,
        const auto input_convert = context.mark_node(std::make_shared(input, element::f32));
        const auto scales_convert = context.mark_node(std::make_shared(scale, element::f32));
        const auto zero_points_convert = context.mark_node(std::make_shared(zero_point, element::f32));
-        const auto axis_convert = context.mark_node(std::make_shared(axis, element::i32));
+        auto axis_convert = try_constfold(context.mark_node(std::make_shared(axis, element::i32)));

        const auto neg_one = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
        const auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
@@ -84,8 +87,9 @@ Output quantize_common(const NodeContext& context,
        const auto out_low_normalized = context.mark_node(std::make_shared(out_low, zero_point_bc));
        const auto out_high_normalized = context.mark_node(std::make_shared(out_high, zero_point_bc));

-        const auto bound_low = context.mark_node(std::make_shared(scale_bc, out_low_normalized));
-        const auto bound_high = context.mark_node(std::make_shared(scale_bc, out_high_normalized));
+        auto bound_low = try_constfold(context.mark_node(std::make_shared(scale_bc, out_low_normalized)));
+        auto bound_high =
+            try_constfold(context.mark_node(std::make_shared(scale_bc, out_high_normalized)));

        const auto quantized_input = context.mark_node(
            std::make_shared(input_convert, bound_low, bound_high, bound_low, bound_high, levels));
@@ -247,7 +251,8 @@ std::shared_ptr u4_compression_stack(const OutputVector& list_elems, int64
    if (axis != -1 && static_cast(axis) != weights_u8->get_shape().size() - 1)
        return nullptr;

-    if (!ov::op::util::has_constant_value(bitwise_and->get_input_node_shared_ptr(1), 0x0F))
+    if (!ov::op::util::has_constant_value(ov::util::get_constant_from_source(bitwise_and->input_value(1)),
+                                          0x0F))
        return nullptr;

    if (!ov::op::util::has_constant_value(bitwise_shift->get_input_node_shared_ptr(1), 4))
diff --git a/tests/layer_tests/py_frontend_tests/test_torch_frontend.py b/tests/layer_tests/py_frontend_tests/test_torch_frontend.py
index 155b772d560222..faee72bb5d938a 100644
--- a/tests/layer_tests/py_frontend_tests/test_torch_frontend.py
+++ b/tests/layer_tests/py_frontend_tests/test_torch_frontend.py
@@ -432,7 +432,7 @@ def forward(self, x):
    converted_model = fe.convert(input_model)
    assert converted_model
    assert [n.get_type_name() for n in converted_model.get_ordered_ops()] == [
-        "Parameter", "Convert", "Convert", "Cos", "Constant", "Relu", "Multiply", "Add", "Result"]
+        "Parameter", "Convert", "Convert", "Cos", "Constant", "Convert", "Relu", "Multiply", "Add", "Result"]

    converted_model = convert_model(model, example_input=(
        torch.randn(100),), extension=[ModuleExtension(CosModel, "aten::sin"), ModuleExtension(model.relu_module, "aten::tan")])
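
For readers unfamiliar with the folding helper this patch leans on, the sketch below is illustrative only and is not part of the diff above; it assumes a small program built inside the OpenVINO source tree where the core dev-API header "openvino/core/validation_util.hpp" is available. It shows what ov::util::get_constant_from_source does with a statically evaluable subgraph, and the empty-Output behaviour covered by the new unit test.

// Illustrative sketch only -- not part of the patch.
#include <iostream>

#include "openvino/core/validation_util.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"

int main() {
    using ov::op::v0::Constant;

    // A tiny constant subgraph: Add(2, 3). Its value bounds are statically known,
    // so get_constant_from_source() can collapse it into a single Constant.
    auto lhs = Constant::create(ov::element::i32, ov::Shape{1}, {2});
    auto rhs = Constant::create(ov::element::i32, ov::Shape{1}, {3});
    auto add = std::make_shared<ov::op::v1::Add>(lhs, rhs);

    if (auto folded = ov::util::get_constant_from_source(add->output(0))) {
        std::cout << "folded value: " << folded->cast_vector<int>()[0] << "\n";  // prints 5
    }

    // Behaviour added by this patch: an empty Output no longer dereferences a
    // null node, it simply yields nullptr (see the new test in validation_utils.cpp).
    auto empty = ov::util::get_constant_from_source(ov::Output<ov::Node>());
    std::cout << std::boolalpha << (empty == nullptr) << "\n";  // prints true
}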