[PT FE] Remove ConstantFolding pass from normalize (#26421)
### Details:
 - *Remove ConstantFolding pass from normalize*
 - *Move `MarkCompressedFloatConstants` from frontends to MOC*

### Tickets:
 - *CVS-151253*
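
The Details above amount to the following pattern inside `FrontEnd::normalize`: instead of interleaving a blanket `ConstantFolding` with per-pass validation, shape/type-agnostic replacement passes run with per-pass validation disabled, and the model is validated once afterwards, only if anything changed. The snippet below is a condensed, hypothetical sketch of that pattern, not the committed code; the real pass list is in the `frontend.cpp` diff further down.

```cpp
// Condensed sketch of the deferred-validation pattern; pass selection is illustrative.
#include "openvino/core/model.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/op_conversions/convert_convertlike.hpp"

static void run_shape_agnostic_passes(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager("Frontend:Pytorch:normalize::no_val");
    // These passes replace ops without relying on input shapes or types,
    // so per-pass validation (and the old blanket ConstantFolding) is unnecessary.
    manager.set_per_pass_validation(false);
    manager.register_pass<ov::pass::ConvertConvertLike>();
    // ... more shape/type-agnostic passes ...
    const bool is_changed = manager.run_passes(model);
    // Validate once, and only if the graph was actually modified.
    if (is_changed)
        model->validate_nodes_and_infer_types();
}
```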

---------

Co-authored-by: Andrei Kochin <[email protected]>
mvafin and andrei-kochin authored Sep 27, 2024
1 parent c023b94 commit 3dee83e
Showing 20 changed files with 218 additions and 122 deletions.
4 changes: 3 additions & 1 deletion src/core/src/validation_util.cpp
@@ -151,7 +151,9 @@ namespace util {
using ov::op::v0::Constant;

std::shared_ptr<Constant> get_constant_from_source(const ov::Output<ov::Node>& source) {
if (const auto& c = ov::as_type_ptr<Constant>(source.get_node_shared_ptr())) {
if (!source.get_node()) {
return {};
} else if (const auto& c = ov::as_type_ptr<Constant>(source.get_node_shared_ptr())) {
return c;
} else if (has_and_set_equal_bounds(source)) {
return std::make_shared<Constant>(source.get_tensor().get_upper_value());
5 changes: 5 additions & 0 deletions src/core/tests/validation_utils.cpp
@@ -49,6 +49,11 @@ TEST(get_constant_from_source, extract_static_dim_from_dynamic_shape_check) {
ASSERT_TRUE(extract_static_dimension->get_output_tensor(0).get_upper_value());
}

TEST(get_constant_from_source, return_nullptr_for_empty_output) {
auto res = ov::util::get_constant_from_source(ov::Output<ov::Node>());
ASSERT_EQ(res, nullptr);
}

TEST(constantfold_subgraph, split) {
std::vector<float> input{0, 1, 2, 3, 4, 5, 6, 7, 8};
auto constant = ov::opset8::Constant::create(ov::element::f32, ov::Shape{input.size()}, input);
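
The new test covers the empty-`Output` guard added above, while the pre-existing `extract_static_dim_from_dynamic_shape_check` test hints at why per-op folding can replace a global `ConstantFolding` pass: value/bound propagation already proves many subexpressions constant. A standalone sketch of that idea (assuming the public OpenVINO opset API; not part of this commit):

```cpp
#include <memory>

#include "openvino/core/validation_util.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/shape_of.hpp"

std::shared_ptr<ov::op::v0::Constant> fold_static_dim_example() {
    using namespace ov;
    // First dimension is dynamic, second is statically 3.
    auto param = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 3});
    auto shape = std::make_shared<op::v3::ShapeOf>(param);
    auto idx = op::v0::Constant::create(element::i32, Shape{1}, {1});
    auto axis = op::v0::Constant::create(element::i32, Shape{}, {0});
    auto dim1 = std::make_shared<op::v8::Gather>(shape, idx, axis);
    // Bound propagation proves dim1 == {3}, so a Constant is returned even though
    // the overall shape is dynamic; an empty Output would now simply yield nullptr.
    return ov::util::get_constant_from_source(dim1);
}
```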
96 changes: 58 additions & 38 deletions src/frontends/pytorch/src/frontend.cpp
@@ -9,7 +9,6 @@
#include "openvino/core/so_extension.hpp"
#include "openvino/frontend/pytorch/extension/conversion.hpp"
#include "openvino/op/util/multi_subgraph_base.hpp"
#include "openvino/pass/constant_folding.hpp"
#include "openvino/util/common_util.hpp"
#include "openvino/util/log.hpp"
#include "place.hpp"
@@ -19,7 +18,6 @@
#include "transformations/common_optimizations/reverse_shape_and_type_infer.hpp"
#include "transformations/control_flow/unroll_if.hpp"
#include "transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp"
#include "transformations/low_precision/mark_dequantization_subgraph.hpp"
#include "transformations/op_conversions/convert_convertlike.hpp"
#include "transformations/op_conversions/convert_convertpromotetypes.hpp"
#include "transformations/resolve_names_collisions.hpp"
@@ -247,8 +245,6 @@ std::shared_ptr<Model> FrontEnd::decode(const InputModel::Ptr& model) const {
}

void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
ov::pass::Manager manager("Frontend:Pytorch:normalize");

bool is_fx = false;
if (model->has_rt_info("decoder_type_name")) {
is_fx = model->get_rt_info()["decoder_type_name"].as<std::string>() == "fx";
@@ -259,49 +255,63 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
// GPTQ transformations need to be executed before other passes
// Once the GPTQ patterns are modified by other transformations,
// they cannot be captured anymore
ov::pass::Manager manager("Frontend:Pytorch:normalize::fx_gptq");
manager.register_pass<ov::frontend::pytorch::pass::GPTQDecompressionReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::GPTQMultPatternReplacer>();
manager.run_passes(model);
}

// the following 2 transformations are needed for keypoint detectron2 models to work.
// AtenIndexToSelect will be called twice
manager.register_pass<ov::pass::ConvertConvertLike>();
manager.register_pass<ov::frontend::pytorch::pass::AtenIndexToSelect>();
{
ov::pass::Manager manager("Frontend:Pytorch:normalize::no_val");
// Passes replacing ops without relying on input shapes or types
manager.set_per_pass_validation(false);
// the following 2 transformations are needed for keypoint detectron2 models to work.
// AtenIndexToSelect will be called twice
manager.register_pass<ov::pass::ConvertConvertLike>();
manager.register_pass<ov::frontend::pytorch::pass::AtenIndexToSelect>();

// Mark quantized and f16/bf16 compressed constants to prevent CF for them,
// so that no extra memory is used for intermediate decompressed constants.
manager.register_pass<ov::pass::MarkDequantizationSubgraph>(
element::TypeVector{element::u8, element::i8, element::u4, element::i4});
manager.register_pass<ov::pass::MarkCompressedFloatConstants>();
manager.register_pass<ov::pass::ConstantFolding>();
// Mark quantized and f16/bf16 compressed constants to prevent CF for them,
// so that no extra memory is used for intermediate decompressed constants.
manager.register_pass<ov::pass::MarkCompressedFloatConstants>();

manager.register_pass<ov::pass::ConvertConvertPromoteTypes>();
manager.register_pass<ov::pass::PushConstantToSubgraph>();
manager.register_pass<ov::frontend::pytorch::pass::TupleUnpackInBodyReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::AtenCatToConcat>();
manager.register_pass<ov::frontend::pytorch::pass::AppendListUnpackReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::AtenStackListConstructReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::AtenEinsumListConstructReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::MinMaxPrimListConstructReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::StringEqualityReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::RFFTNComplexReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::IRFFTNComplexReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::PrimTupleUnpackReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::DecomposeListTupleResults>();
manager.register_pass<ov::frontend::pytorch::pass::DecomposeUnpackParameters>();
manager.register_pass<ov::frontend::pytorch::pass::DictParameterResolver>();
manager.register_pass<ov::frontend::pytorch::pass::DictResultResolver>();
manager.register_pass<ov::frontend::pytorch::pass::QuantizedNodeRemover>();
manager.register_pass<ov::frontend::pytorch::pass::SoftmaxReshapeElimination>();
manager.register_pass<ov::frontend::pytorch::pass::ReversepropResolver>();
manager.register_pass<ov::frontend::pytorch::pass::MovePackThroughLstm>();
manager.register_pass<ov::frontend::pytorch::pass::RemovePackingOps>();
bool is_changed = manager.run_passes(model);

// make validation after previously non-validated passes
if (is_changed)
model->validate_nodes_and_infer_types();
}

manager.register_pass<ov::pass::ConvertConvertPromoteTypes>();
manager.register_pass<ov::pass::PushConstantToSubgraph>();
ov::pass::Manager manager("Frontend:Pytorch:normalize");
manager.register_pass<ov::pass::UnrollIf>();
manager.register_pass<ov::frontend::pytorch::pass::TupleUnpackInBodyReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::AtenCatToConcat>();
manager.register_pass<ov::frontend::pytorch::pass::AppendListUnpackReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::AtenStackListConstructReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::PrimListUnpackReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::AtenGetItemReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::ListConstructReplacer>();
// TODO: remove AtenIndexToSelect when problem with dynamic input rank is gone.
manager.register_pass<ov::frontend::pytorch::pass::AtenIndexToSelect>();
manager.register_pass<ov::frontend::pytorch::pass::AtenIndexPutReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::PrimListConstructPadReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::AtenEinsumListConstructReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::MinMaxPrimListConstructReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::StringEqualityReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::DecomposeUnpackParameters>();
manager.register_pass<ov::frontend::pytorch::pass::RFFTNComplexReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::IRFFTNComplexReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::PrimTupleUnpackReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::DecomposeListTupleResults>();
manager.register_pass<ov::frontend::pytorch::pass::DictParameterResolver>();
manager.register_pass<ov::frontend::pytorch::pass::DictResultResolver>();
manager.register_pass<ov::frontend::pytorch::pass::IndexLoopGetitemReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::QuantizedNodeRemover>();
manager.register_pass<ov::frontend::pytorch::pass::SoftmaxReshapeElimination>();

// Check if model is symmetrically quantized
bool sym = false;
@@ -311,15 +321,25 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
}
manager.register_pass<ov::frontend::pytorch::pass::U4BlockRepack>(sym);

manager.register_pass<ov::frontend::pytorch::pass::ReversepropResolver>();
manager.register_pass<ov::frontend::pytorch::pass::MovePackThroughLstm>();
manager.register_pass<ov::frontend::pytorch::pass::RemovePackingOps>();
manager.register_pass<ov::pass::RemoveMultiSubGraphOpDanglingParamsResults>();
manager.register_pass<ov::pass::ReverseShapeAndTypeInfer>();
manager.register_pass<ov::pass::ResolveNameCollisions>(true);
manager.register_pass<ov::pass::ConvertConvertLike>();
manager.run_passes(model);

{
ov::pass::Manager manager("Frontend:Pytorch:normalize::followup_no_val");
manager.set_per_pass_validation(false);

// ReverseShapeAndTypeInfer needs to run on a validated model because it relies on shapes and types
manager.register_pass<ov::pass::ReverseShapeAndTypeInfer>();
// ConvertConvertLike will benefit from types inserted by ReverseShapeAndTypeInfer
manager.register_pass<ov::pass::ConvertConvertLike>();
manager.register_pass<ov::pass::ResolveNameCollisions>(true);
bool is_changed = manager.run_passes(model);

// make validation after previously non-validated passes
if (is_changed)
model->validate_nodes_and_infer_types();
}

// Usually if nn.Module.forward is given as a source model for conversion, there is the first Parameter
// that represents original `self` argument in forward(self, ...). `self` shouldn't play any role in model
// inference if model is completely frozen and all methods are inlined. So we check if it doesn't have any
2 changes: 1 addition & 1 deletion src/frontends/pytorch/src/op/expand.cpp
@@ -28,7 +28,7 @@ OutputVector translate_expand(const NodeContext& context) {
// aten::expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)
num_inputs_check(context, 2, 3);
auto x = context.get_input(0);
auto sizes = context.get_input(1);
auto sizes = get_input_concat_if_list(context, 1);
// TODO: figure out what implicit means
PYTORCH_OP_CONVERSION_CHECK(context.input_is_none(2) || context.const_input<bool>(2) == false,
"Unexpected value of implicit for expand operation");
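
`get_input_concat_if_list` replaces plain `get_input` in several translators below, so a list-typed size/shape argument becomes a single 1-D tensor that downstream ops (and local constant folding) can consume. Its implementation is not part of this diff; the sketch below is only a guess at the general shape of such a helper (helper name reused, body hypothetical):

```cpp
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "utils.hpp"  // frontend-internal helpers such as cast_fw_node (assumed)

using namespace ov;
using namespace ov::op;
using namespace ov::frontend::pytorch;

// Hypothetical sketch: if input `index` is a prim::ListConstruct of scalars,
// build a 1-D vector from its elements; otherwise return the input unchanged.
Output<Node> concat_if_list_sketch(const NodeContext& context, size_t index) {
    auto input = context.get_input(static_cast<int>(index));
    auto list_node = cast_fw_node(input.get_node_shared_ptr(), "prim::ListConstruct");
    if (!list_node)
        return input;  // already a tensor, e.g. the output of aten::size
    OutputVector elements;
    auto axis_0 = v0::Constant::create(element::i32, Shape{1}, {0});
    for (const auto& item : list_node->input_values()) {
        // Promote each scalar element to shape {1} so Concat yields a 1-D vector.
        elements.push_back(context.mark_node(std::make_shared<v0::Unsqueeze>(item, axis_0)));
    }
    return context.mark_node(std::make_shared<v0::Concat>(elements, 0));
}
```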
14 changes: 7 additions & 7 deletions src/frontends/pytorch/src/op/full.cpp
@@ -57,7 +57,7 @@ Output<Node> base_translate_full_with_convert(const NodeContext& context,

OutputVector translate_full(const NodeContext& context) {
num_inputs_check(context, 2, 6);
auto sizes = context.get_input(0);
auto sizes = get_input_concat_if_list(context, 0);
auto value = context.get_input(1);
auto num_inputs = context.get_input_size();
if (num_inputs < 6) {
@@ -136,7 +136,7 @@ OutputVector translate_fill(const NodeContext& context) {
OutputVector translate_new_full(const NodeContext& context) {
num_inputs_check(context, 3, 7);
auto input = context.get_input(0);
auto sizes = context.get_input(1);
auto sizes = get_input_concat_if_list(context, 1);
auto value = context.get_input(2);
if (context.get_input_size() == 7 && !context.input_is_none(3)) {
return {base_translate_full_with_convert(context, sizes, value, 3)};
@@ -161,7 +161,7 @@ OutputVector translate_new_full_fx(const NodeContext& context) {

OutputVector translate_zeros(const NodeContext& context) {
num_inputs_check(context, 2, 5);
auto sizes = context.get_input(0);
auto sizes = get_input_concat_if_list(context, 0);
auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
auto num_inputs = context.get_input_size();
if (num_inputs < 5) {
@@ -218,7 +218,7 @@ OutputVector translate_zeros_like_fx(const NodeContext& context) {
OutputVector translate_new_zeros(const NodeContext& context) {
num_inputs_check(context, 2, 6);
auto input = context.get_input(0);
auto sizes = context.get_input(1);
auto sizes = get_input_concat_if_list(context, 1);
auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
if (context.get_input_size() == 6 && !context.input_is_none(2)) {
return {base_translate_full_with_convert(context, sizes, value, 2)};
@@ -243,7 +243,7 @@

OutputVector translate_ones(const NodeContext& context) {
num_inputs_check(context, 1, 5);
auto sizes = context.get_input(0);
auto sizes = get_input_concat_if_list(context, 0);
auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1}));
auto num_inputs = context.get_input_size();
if (num_inputs < 5) {
@@ -300,7 +300,7 @@ OutputVector translate_ones_like_fx(const NodeContext& context) {
OutputVector translate_new_ones(const NodeContext& context) {
num_inputs_check(context, 2, 6);
auto input = context.get_input(0);
auto sizes = context.get_input(1);
auto sizes = get_input_concat_if_list(context, 1);
auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1}));
if (context.get_input_size() == 6 && !context.input_is_none(2)) {
return {base_translate_full_with_convert(context, sizes, value, 2)};
@@ -328,7 +328,7 @@ OutputVector translate_empty(const NodeContext& context) {
// pin_memory=None, MemoryFormat? memory_format=None) -> Tensor layout, device and work with memory ignored on our
// side, so just skip these parameters
num_inputs_check(context, 1, 6);
auto sizes = context.get_input(0);
auto sizes = get_input_concat_if_list(context, 0);
// In OV uninitialized data is not supported, so we create a tensor filled with zeros with a given shape and type.
auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
int dtype_id = 1;
3 changes: 2 additions & 1 deletion src/frontends/pytorch/src/op/reshape.cpp
@@ -21,7 +21,8 @@ OutputVector translate_reshape(const NodeContext& context) {
// Schema: aten::reshape(Tensor input, int[] shape) -> Tensor
// For shape parameter, int[] is converted into single dimensional Tensor.
num_inputs_check(context, 2, 2);
auto reshape = std::make_shared<ov::op::v1::Reshape>(context.get_input(0), context.get_input(1), false);
auto shape = get_input_concat_if_list(context, 1);
auto reshape = std::make_shared<ov::op::v1::Reshape>(context.get_input(0), shape, false);
return {context.mark_node(reshape)};
};

2 changes: 1 addition & 1 deletion src/frontends/pytorch/src/op/roll.cpp
@@ -20,7 +20,7 @@ using namespace ov::op;
OutputVector translate_roll(const NodeContext& context) {
num_inputs_check(context, 2, 3);
const auto data = context.get_input(0);
const auto shifts = context.get_input(1);
const auto shifts = get_input_concat_if_list(context, 1);
Output<Node> axes;
bool on_flattened = context.input_is_none(2);
if (!on_flattened) {
10 changes: 10 additions & 0 deletions src/frontends/pytorch/src/op/slice.cpp
@@ -6,6 +6,7 @@

#include <climits>

#include "openvino/core/validation_util.hpp"
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/reshape.hpp"
@@ -51,6 +52,9 @@ OutputVector translate_slice_common(const NodeContext& context,
if (start.get_partial_shape().rank().is_dynamic() || start.get_partial_shape().rank().get_length() == 0) {
start = context.mark_node(std::make_shared<v1::Reshape>(start, dims_1d_shape, false));
}
if (const auto start_const = ov::util::get_constant_from_source(start)) {
start = start_const;
}
} else {
start = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
}
@@ -63,6 +67,9 @@ OutputVector translate_slice_common(const NodeContext& context,
(!(end.get_partial_shape().rank().is_dynamic()) && end.get_partial_shape().rank().get_length() == 0)) {
end = context.mark_node(std::make_shared<v1::Reshape>(end, dims_1d_shape, false));
}
if (const auto end_const = ov::util::get_constant_from_source(end)) {
end = end_const;
}
} else {
end = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {INT_MAX}));
}
@@ -72,6 +79,9 @@ OutputVector translate_slice_common(const NodeContext& context,
if (step.get_partial_shape().rank().is_dynamic() || step.get_partial_shape().rank().get_length() == 0) {
step = context.mark_node(std::make_shared<v1::Reshape>(step, dims_1d_shape, false));
}
if (const auto step_const = ov::util::get_constant_from_source(step)) {
step = step_const;
}
} else {
step = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
}
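
With `ConstantFolding` gone from `normalize`, `translate_slice_common` now folds `start`, `end`, and `step` locally as soon as value propagation can prove them. The same guard appears three times above; a small helper capturing the pattern (illustrative only, not part of the commit) could read:

```cpp
#include "openvino/core/node.hpp"
#include "openvino/core/validation_util.hpp"

// Illustrative only: collapse an Output to a Constant when its value is statically
// provable; otherwise leave the original subgraph in place.
inline ov::Output<ov::Node> fold_if_constant(const ov::Output<ov::Node>& out) {
    if (const auto folded = ov::util::get_constant_from_source(out)) {
        return folded;
    }
    return out;
}

// Hypothetical usage inside the translator: start = fold_if_constant(start);
```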
16 changes: 11 additions & 5 deletions src/frontends/pytorch/src/op/transpose.cpp
@@ -4,8 +4,8 @@

#include "openvino/op/transpose.hpp"

#include "openvino/core/validation_util.hpp"
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
@@ -29,8 +29,8 @@

OutputVector translate_transpose(const NodeContext& context) {
num_inputs_check(context, 3, 3);
auto data = context.get_input(0);
Output<Node> rank;
std::tie(std::ignore, rank) = get_shape_rank(context, context.get_input(0), true);
std::tie(std::ignore, rank) = get_shape_rank(context, data, true);
auto dim0_node = get_input_as_i32(context, 1);
auto dim1_node = get_input_as_i32(context, 2);
dim0_node = normalize_axis(context, dim0_node, rank);
@@ -44,10 +45,15 @@
auto dim1_node_ = std::make_shared<v0::Unsqueeze>(dim1_node, axis_0);
auto indices = std::make_shared<v0::Concat>(OutputVector{dim0_node_, dim1_node_}, 0);
auto updates = std::make_shared<v0::Concat>(OutputVector{dim1_node_, dim0_node_}, 0);
auto scatter = std::make_shared<v3::ScatterElementsUpdate>(range, indices, updates, axis_0);
context.mark_nodes({start, step, range, axis_0, dim0_node_, dim1_node_, indices, updates, scatter});
Output<Node> scatter = std::make_shared<v3::ScatterElementsUpdate>(range, indices, updates, axis_0);
if (const auto scatter_const = ov::util::get_constant_from_source(scatter)) {
scatter = context.mark_node(scatter_const);
} else {
context.mark_nodes(
{start, step, range, axis_0, dim0_node_, dim1_node_, indices, updates, scatter.get_node_shared_ptr()});
}

return {context.mark_node(std::make_shared<v1::Transpose>(context.get_input(0), scatter))};
return {context.mark_node(std::make_shared<v1::Transpose>(data, scatter))};
};

OutputVector translate_t(const NodeContext& context) {
