[PT FE] Remove ConstantFolding pass from normalize #26421

Merged: 21 commits, Sep 27, 2024
4 changes: 3 additions & 1 deletion src/core/src/validation_util.cpp
@@ -151,7 +151,9 @@ namespace util {
using ov::op::v0::Constant;

std::shared_ptr<Constant> get_constant_from_source(const ov::Output<ov::Node>& source) {
if (const auto& c = ov::as_type_ptr<Constant>(source.get_node_shared_ptr())) {
if (!source.get_node()) {
return {};
} else if (const auto& c = ov::as_type_ptr<Constant>(source.get_node_shared_ptr())) {
return c;
} else if (has_and_set_equal_bounds(source)) {
return std::make_shared<Constant>(source.get_tensor().get_upper_value());
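Note on the guarded helper above: a minimal standalone sketch of the expected behavior (illustrative only, not part of this diff; the folded result relies on the bound-evaluation path already present in the function):

#include <memory>

#include "openvino/core/validation_util.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"

void get_constant_from_source_sketch() {
    // A subgraph whose value is statically known folds into a single Constant.
    auto a = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {1, 2, 3});
    auto b = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {4, 5, 6});
    auto sum = std::make_shared<ov::op::v1::Add>(a, b);
    auto folded = ov::util::get_constant_from_source(sum);  // expected: Constant holding {5, 7, 9}

    // With the new guard, an empty Output yields nullptr instead of dereferencing a null node.
    auto empty = ov::util::get_constant_from_source(ov::Output<ov::Node>());  // nullptr
}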
5 changes: 5 additions & 0 deletions src/core/tests/validation_utils.cpp
@@ -49,6 +49,11 @@ TEST(get_constant_from_source, extract_static_dim_from_dynamic_shape_check) {
ASSERT_TRUE(extract_static_dimension->get_output_tensor(0).get_upper_value());
}

TEST(get_constant_from_source, return_nullptr_for_empty_output) {
auto res = ov::util::get_constant_from_source(ov::Output<ov::Node>());
ASSERT_EQ(res, nullptr);
}

TEST(constantfold_subgraph, split) {
std::vector<float> input{0, 1, 2, 3, 4, 5, 6, 7, 8};
auto constant = ov::opset8::Constant::create(ov::element::f32, ov::Shape{input.size()}, input);
96 changes: 58 additions & 38 deletions src/frontends/pytorch/src/frontend.cpp
@@ -9,7 +9,6 @@
#include "openvino/core/so_extension.hpp"
#include "openvino/frontend/pytorch/extension/conversion.hpp"
#include "openvino/op/util/multi_subgraph_base.hpp"
#include "openvino/pass/constant_folding.hpp"
#include "openvino/util/common_util.hpp"
#include "openvino/util/log.hpp"
#include "place.hpp"
@@ -19,7 +18,6 @@
#include "transformations/common_optimizations/reverse_shape_and_type_infer.hpp"
#include "transformations/control_flow/unroll_if.hpp"
#include "transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp"
#include "transformations/low_precision/mark_dequantization_subgraph.hpp"
#include "transformations/op_conversions/convert_convertlike.hpp"
#include "transformations/op_conversions/convert_convertpromotetypes.hpp"
#include "transformations/resolve_names_collisions.hpp"
@@ -247,8 +245,6 @@ std::shared_ptr<Model> FrontEnd::decode(const InputModel::Ptr& model) const {
}

void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
ov::pass::Manager manager("Frontend:Pytorch:normalize");

bool is_fx = false;
if (model->has_rt_info("decoder_type_name")) {
is_fx = model->get_rt_info()["decoder_type_name"].as<std::string>() == "fx";
@@ -259,49 +255,63 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
// GPTQ transformations need to be executed before other passes
// Once the GPTQ patterns are modified by other transformations,
// they cannot be captured anymore
ov::pass::Manager manager("Frontend:Pytorch:normalize::fx_gptq");
manager.register_pass<ov::frontend::pytorch::pass::GPTQDecompressionReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::GPTQMultPatternReplacer>();
manager.run_passes(model);
}

// the following 2 transformations are needed for keypoint detectron2 models to work.
// AtenIndexToSelect will be called twice
manager.register_pass<ov::pass::ConvertConvertLike>();
manager.register_pass<ov::frontend::pytorch::pass::AtenIndexToSelect>();
{
ov::pass::Manager manager("Frontend:Pytorch:normalize::no_val");
// Passes replacing ops without relying on input shapes or types
manager.set_per_pass_validation(false);
// the following 2 transformations are needed for keypoint detectron2 models to work.
// AtenIndexToSelect will be called twice
manager.register_pass<ov::pass::ConvertConvertLike>();
manager.register_pass<ov::frontend::pytorch::pass::AtenIndexToSelect>();

// Mark quantized and f16/bf16 compressed constants to prevent CF for them,
// so that no extra memory is used for intermediate decompressed constants.
manager.register_pass<ov::pass::MarkDequantizationSubgraph>(
element::TypeVector{element::u8, element::i8, element::u4, element::i4});
manager.register_pass<ov::pass::MarkCompressedFloatConstants>();
manager.register_pass<ov::pass::ConstantFolding>();
// Mark quantized and f16/bf16 compressed constants to prevent CF for them,
// so that no extra memory is used for intermediate decompressed constants.
manager.register_pass<ov::pass::MarkCompressedFloatConstants>();

manager.register_pass<ov::pass::ConvertConvertPromoteTypes>();
manager.register_pass<ov::pass::PushConstantToSubgraph>();
manager.register_pass<ov::frontend::pytorch::pass::TupleUnpackInBodyReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::AtenCatToConcat>();
manager.register_pass<ov::frontend::pytorch::pass::AppendListUnpackReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::AtenStackListConstructReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::AtenEinsumListConstructReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::MinMaxPrimListConstructReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::StringEqualityReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::RFFTNComplexReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::IRFFTNComplexReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::PrimTupleUnpackReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::DecomposeListTupleResults>();
manager.register_pass<ov::frontend::pytorch::pass::DecomposeUnpackParameters>();
manager.register_pass<ov::frontend::pytorch::pass::DictParameterResolver>();
manager.register_pass<ov::frontend::pytorch::pass::DictResultResolver>();
manager.register_pass<ov::frontend::pytorch::pass::QuantizedNodeRemover>();
manager.register_pass<ov::frontend::pytorch::pass::SoftmaxReshapeElimination>();
manager.register_pass<ov::frontend::pytorch::pass::ReversepropResolver>();
manager.register_pass<ov::frontend::pytorch::pass::MovePackThroughLstm>();
manager.register_pass<ov::frontend::pytorch::pass::RemovePackingOps>();
bool is_changed = manager.run_passes(model);

// validate after the previously non-validated passes
if (is_changed)
model->validate_nodes_and_infer_types();
}

manager.register_pass<ov::pass::ConvertConvertPromoteTypes>();
manager.register_pass<ov::pass::PushConstantToSubgraph>();
ov::pass::Manager manager("Frontend:Pytorch:normalize");
manager.register_pass<ov::pass::UnrollIf>();
manager.register_pass<ov::frontend::pytorch::pass::TupleUnpackInBodyReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::AtenCatToConcat>();
manager.register_pass<ov::frontend::pytorch::pass::AppendListUnpackReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::AtenStackListConstructReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::PrimListUnpackReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::AtenGetItemReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::ListConstructReplacer>();
// TODO: remove AtenIndexToSelect when problem with dynamic input rank is gone.
manager.register_pass<ov::frontend::pytorch::pass::AtenIndexToSelect>();
manager.register_pass<ov::frontend::pytorch::pass::AtenIndexPutReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::PrimListConstructPadReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::AtenEinsumListConstructReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::MinMaxPrimListConstructReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::StringEqualityReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::DecomposeUnpackParameters>();
manager.register_pass<ov::frontend::pytorch::pass::RFFTNComplexReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::IRFFTNComplexReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::PrimTupleUnpackReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::DecomposeListTupleResults>();
manager.register_pass<ov::frontend::pytorch::pass::DictParameterResolver>();
manager.register_pass<ov::frontend::pytorch::pass::DictResultResolver>();
manager.register_pass<ov::frontend::pytorch::pass::IndexLoopGetitemReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::QuantizedNodeRemover>();
manager.register_pass<ov::frontend::pytorch::pass::SoftmaxReshapeElimination>();

// Check if model is symmetrically quantized
bool sym = false;
@@ -311,15 +321,25 @@
}
manager.register_pass<ov::frontend::pytorch::pass::U4BlockRepack>(sym);

manager.register_pass<ov::frontend::pytorch::pass::ReversepropResolver>();
manager.register_pass<ov::frontend::pytorch::pass::MovePackThroughLstm>();
manager.register_pass<ov::frontend::pytorch::pass::RemovePackingOps>();
manager.register_pass<ov::pass::RemoveMultiSubGraphOpDanglingParamsResults>();
manager.register_pass<ov::pass::ReverseShapeAndTypeInfer>();
manager.register_pass<ov::pass::ResolveNameCollisions>(true);
manager.register_pass<ov::pass::ConvertConvertLike>();
manager.run_passes(model);

{
ov::pass::Manager manager("Frontend:Pytorch:normalize::followup_no_val");
manager.set_per_pass_validation(false);

// ReverseShapeAndTypeInfer needs to run on a validated model, as it relies on shapes and types
manager.register_pass<ov::pass::ReverseShapeAndTypeInfer>();
// ConvertConvertLike will benefit from types inserted by ReverseShapeAndTypeInfer
manager.register_pass<ov::pass::ConvertConvertLike>();
manager.register_pass<ov::pass::ResolveNameCollisions>(true);
bool is_changed = manager.run_passes(model);

// validate after the previously non-validated passes
if (is_changed)
model->validate_nodes_and_infer_types();
}

// Usually if nn.Module.forward is given as a source model for conversion, there is the first Parameter
// that represents original `self` argument in forward(self, ...). `self` shouldn't play any role in model
// inference if model is completely frozen and all methods are inlined. So we check if it doesn't have any
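The normalize() restructuring above drops the ConstantFolding pass and splits the single pass manager into stages that disable per-pass validation, then validate the model once per stage. A minimal sketch of that pattern, built only from APIs visible in the diff (the registered pass is just a placeholder):

#include "openvino/core/model.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/op_conversions/convert_convertlike.hpp"

void run_stage_without_per_pass_validation(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager("Example:normalize_stage");
    manager.set_per_pass_validation(false);                 // skip validation between passes
    manager.register_pass<ov::pass::ConvertConvertLike>();  // placeholder: any shape-agnostic pass
    const bool is_changed = manager.run_passes(model);      // true if any pass modified the model
    if (is_changed)
        model->validate_nodes_and_infer_types();            // validate once at the end of the stage
}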
2 changes: 1 addition & 1 deletion src/frontends/pytorch/src/op/expand.cpp
@@ -28,7 +28,7 @@ OutputVector translate_expand(const NodeContext& context) {
// aten::expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)
num_inputs_check(context, 2, 3);
auto x = context.get_input(0);
auto sizes = context.get_input(1);
auto sizes = get_input_concat_if_list(context, 1);
// TODO: figure out what implicit means
PYTORCH_OP_CONVERSION_CHECK(context.input_is_none(2) || context.const_input<bool>(2) == false,
"Unexpected value of implicit for expand operation");
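In this and the following translators, context.get_input(n) is replaced by get_input_concat_if_list(context, n). Judging from the name and the surrounding changes, the helper presumably materializes a Python list input as a 1-D tensor directly, rather than leaving that to the removed ConstantFolding pass. A rough conceptual sketch of that idea follows; it is an assumption about the behavior, not the helper's actual implementation, and concat_if_list_sketch is a hypothetical name:

#include <memory>

#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/unsqueeze.hpp"

// Hypothetical illustration: build a 1-D tensor out of scalar list elements,
// or fall back to the tensor input when the argument is not a list.
ov::Output<ov::Node> concat_if_list_sketch(const ov::OutputVector& list_elements,
                                            const ov::Output<ov::Node>& fallback) {
    if (list_elements.empty())
        return fallback;  // not a list: use the input as-is
    ov::OutputVector unsqueezed;
    auto zero = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{1}, {0});
    for (const auto& elem : list_elements)
        unsqueezed.push_back(std::make_shared<ov::op::v0::Unsqueeze>(elem, zero));
    return std::make_shared<ov::op::v0::Concat>(unsqueezed, 0);  // 1-D sizes/shape tensor
}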
14 changes: 7 additions & 7 deletions src/frontends/pytorch/src/op/full.cpp
@@ -57,7 +57,7 @@ Output<Node> base_translate_full_with_convert(const NodeContext& context,

OutputVector translate_full(const NodeContext& context) {
num_inputs_check(context, 2, 6);
auto sizes = context.get_input(0);
auto sizes = get_input_concat_if_list(context, 0);
auto value = context.get_input(1);
auto num_inputs = context.get_input_size();
if (num_inputs < 6) {
@@ -136,7 +136,7 @@ OutputVector translate_fill(const NodeContext& context) {
OutputVector translate_new_full(const NodeContext& context) {
num_inputs_check(context, 3, 7);
auto input = context.get_input(0);
auto sizes = context.get_input(1);
auto sizes = get_input_concat_if_list(context, 1);
auto value = context.get_input(2);
if (context.get_input_size() == 7 && !context.input_is_none(3)) {
return {base_translate_full_with_convert(context, sizes, value, 3)};
@@ -161,7 +161,7 @@ OutputVector translate_new_full_fx(const NodeContext& context) {

OutputVector translate_zeros(const NodeContext& context) {
num_inputs_check(context, 2, 5);
auto sizes = context.get_input(0);
auto sizes = get_input_concat_if_list(context, 0);
auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
auto num_inputs = context.get_input_size();
if (num_inputs < 5) {
@@ -218,7 +218,7 @@ OutputVector translate_zeros_like_fx(const NodeContext& context) {
OutputVector translate_new_zeros(const NodeContext& context) {
num_inputs_check(context, 2, 6);
auto input = context.get_input(0);
auto sizes = context.get_input(1);
auto sizes = get_input_concat_if_list(context, 1);
auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
if (context.get_input_size() == 6 && !context.input_is_none(2)) {
return {base_translate_full_with_convert(context, sizes, value, 2)};
@@ -243,7 +243,7 @@

OutputVector translate_ones(const NodeContext& context) {
num_inputs_check(context, 1, 5);
auto sizes = context.get_input(0);
auto sizes = get_input_concat_if_list(context, 0);
auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1}));
auto num_inputs = context.get_input_size();
if (num_inputs < 5) {
@@ -300,7 +300,7 @@ OutputVector translate_ones_like_fx(const NodeContext& context) {
OutputVector translate_new_ones(const NodeContext& context) {
num_inputs_check(context, 2, 6);
auto input = context.get_input(0);
auto sizes = context.get_input(1);
auto sizes = get_input_concat_if_list(context, 1);
auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1}));
if (context.get_input_size() == 6 && !context.input_is_none(2)) {
return {base_translate_full_with_convert(context, sizes, value, 2)};
@@ -328,7 +328,7 @@ OutputVector translate_empty(const NodeContext& context) {
// pin_memory=None, MemoryFormat? memory_format=None) -> Tensor layout, device and work with memory ignored on our
// side, so just skip these parameters
num_inputs_check(context, 1, 6);
auto sizes = context.get_input(0);
auto sizes = get_input_concat_if_list(context, 0);
// In OV uninitialized data is not supported, so we create a tensor filled with zeros with a given shape and type.
auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
int dtype_id = 1;
3 changes: 2 additions & 1 deletion src/frontends/pytorch/src/op/reshape.cpp
@@ -21,7 +21,8 @@ OutputVector translate_reshape(const NodeContext& context) {
// Schema: aten::reshape(Tensor input, int[] shape) -> Tensor
// For shape parameter, int[] is converted into single dimensional Tensor.
num_inputs_check(context, 2, 2);
auto reshape = std::make_shared<ov::op::v1::Reshape>(context.get_input(0), context.get_input(1), false);
auto shape = get_input_concat_if_list(context, 1);
auto reshape = std::make_shared<ov::op::v1::Reshape>(context.get_input(0), shape, false);
return {context.mark_node(reshape)};
};

2 changes: 1 addition & 1 deletion src/frontends/pytorch/src/op/roll.cpp
@@ -20,7 +20,7 @@ using namespace ov::op;
OutputVector translate_roll(const NodeContext& context) {
num_inputs_check(context, 2, 3);
const auto data = context.get_input(0);
const auto shifts = context.get_input(1);
const auto shifts = get_input_concat_if_list(context, 1);
Output<Node> axes;
bool on_flattened = context.input_is_none(2);
if (!on_flattened) {
10 changes: 10 additions & 0 deletions src/frontends/pytorch/src/op/slice.cpp
@@ -6,6 +6,7 @@

#include <climits>

#include "openvino/core/validation_util.hpp"
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/reshape.hpp"
@@ -51,6 +52,9 @@ OutputVector translate_slice_common(const NodeContext& context,
if (start.get_partial_shape().rank().is_dynamic() || start.get_partial_shape().rank().get_length() == 0) {
start = context.mark_node(std::make_shared<v1::Reshape>(start, dims_1d_shape, false));
}
if (const auto start_const = ov::util::get_constant_from_source(start)) {
start = start_const;
}
} else {
start = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
}
@@ -63,6 +67,9 @@
(!(end.get_partial_shape().rank().is_dynamic()) && end.get_partial_shape().rank().get_length() == 0)) {
end = context.mark_node(std::make_shared<v1::Reshape>(end, dims_1d_shape, false));
}
if (const auto end_const = ov::util::get_constant_from_source(end)) {
end = end_const;
}
} else {
end = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {INT_MAX}));
}
@@ -72,6 +79,9 @@
if (step.get_partial_shape().rank().is_dynamic() || step.get_partial_shape().rank().get_length() == 0) {
step = context.mark_node(std::make_shared<v1::Reshape>(step, dims_1d_shape, false));
}
if (const auto step_const = ov::util::get_constant_from_source(step)) {
step = step_const;
}
} else {
step = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
}
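The same fold-to-constant idiom is applied to start, end, and step above; restated as a standalone sketch (not code from this PR, and fold_if_constant is a hypothetical name):

#include "openvino/core/node.hpp"
#include "openvino/core/validation_util.hpp"

// Hypothetical helper: collapse a statically-known subgraph into a single Constant
// so downstream logic sees a static value without a separate ConstantFolding pass.
ov::Output<ov::Node> fold_if_constant(const ov::Output<ov::Node>& value) {
    if (const auto constant = ov::util::get_constant_from_source(value)) {
        return constant;  // lower and upper bounds coincide: value is known at conversion time
    }
    return value;  // otherwise keep the original, possibly dynamic, subgraph
}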
16 changes: 11 additions & 5 deletions src/frontends/pytorch/src/op/transpose.cpp
@@ -4,8 +4,8 @@

#include "openvino/op/transpose.hpp"

#include "openvino/core/validation_util.hpp"
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
@@ -29,8 +29,9 @@

OutputVector translate_transpose(const NodeContext& context) {
num_inputs_check(context, 3, 3);
auto data = context.get_input(0);
Output<Node> rank;
std::tie(std::ignore, rank) = get_shape_rank(context, context.get_input(0), true);
std::tie(std::ignore, rank) = get_shape_rank(context, data, true);
auto dim0_node = get_input_as_i32(context, 1);
auto dim1_node = get_input_as_i32(context, 2);
dim0_node = normalize_axis(context, dim0_node, rank);
@@ -44,10 +45,15 @@
auto dim1_node_ = std::make_shared<v0::Unsqueeze>(dim1_node, axis_0);
auto indices = std::make_shared<v0::Concat>(OutputVector{dim0_node_, dim1_node_}, 0);
auto updates = std::make_shared<v0::Concat>(OutputVector{dim1_node_, dim0_node_}, 0);
auto scatter = std::make_shared<v3::ScatterElementsUpdate>(range, indices, updates, axis_0);
context.mark_nodes({start, step, range, axis_0, dim0_node_, dim1_node_, indices, updates, scatter});
Output<Node> scatter = std::make_shared<v3::ScatterElementsUpdate>(range, indices, updates, axis_0);
if (const auto scatter_const = ov::util::get_constant_from_source(scatter)) {
scatter = context.mark_node(scatter_const);
} else {
context.mark_nodes(
{start, step, range, axis_0, dim0_node_, dim1_node_, indices, updates, scatter.get_node_shared_ptr()});
}

return {context.mark_node(std::make_shared<v1::Transpose>(context.get_input(0), scatter))};
return {context.mark_node(std::make_shared<v1::Transpose>(data, scatter))};
};

OutputVector translate_t(const NodeContext& context) {