Skip to content

Commit

Permalink
[PT FE] Support negative orders in aten::permute (#28459)
Browse files Browse the repository at this point in the history
### Details:
 - *Support negative orders in `aten::permute`*

### Tickets:
 - *ticket-id*

---------

Signed-off-by: Maxim Vafin <[email protected]>
  • Loading branch information
mvafin authored Jan 15, 2025
1 parent 2f5af17 commit 3503fe6
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 19 deletions.
31 changes: 31 additions & 0 deletions src/frontends/pytorch/src/op/permute.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/core/validation_util.hpp"
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/transpose.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_permute(const NodeContext& context) {
num_inputs_check(context, 2, 2);
auto data = context.get_input(0);
auto order = get_input_concat_if_list(context, 1);
auto rank = std::get<1>(get_shape_rank(context, data));
auto rank_converted = context.mark_node(std::make_shared<ov::op::v1::ConvertLike>(rank, order));
auto order_normalized = normalize_axis(context, order, rank_converted);
if (const auto order_const = ov::util::get_constant_from_source(order_normalized)) {
order_normalized = order_const;
}
return {context.mark_node(std::make_shared<ov::op::v1::Transpose>(data, order_normalized))};
}

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
5 changes: 3 additions & 2 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ OP_CONVERTER(translate_outer);
OP_CONVERTER(translate_pack_padded_sequence);
OP_CONVERTER(translate_pad);
OP_CONVERTER(translate_pad_packed_sequence);
OP_CONVERTER(translate_permute);
OP_CONVERTER(translate_pairwise_distance);
OP_CONVERTER(translate_pixel_shuffle);
OP_CONVERTER(translate_pixel_unshuffle);
Expand Down Expand Up @@ -589,7 +590,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::outer", op::translate_outer},
{"aten::pad", op::translate_pad},
{"aten::pairwise_distance", op::translate_pairwise_distance},
{"aten::permute", op::translate_1to1_match_2_inputs<opset10::Transpose>},
{"aten::permute", op::translate_permute},
{"aten::pixel_shuffle", op::translate_pixel_shuffle},
{"aten::pixel_unshuffle", op::translate_pixel_unshuffle},
{"aten::prelu", op::translate_1to1_match_2_inputs<opset10::PRelu>},
Expand Down Expand Up @@ -918,7 +919,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten.ones.default", op::translate_ones_fx},
{"aten.ones.names", op::translate_ones_fx},
{"aten.ones_like.default", op::translate_ones_like_fx},
{"aten.permute.default", op::translate_1to1_match_2_inputs<opset10::Transpose>},
{"aten.permute.default", op::translate_permute},
{"aten.pow.Scalar", op::translate_pow},
{"aten.pow.Tensor_Scalar", op::translate_pow},
{"aten.pow.Tensor_Tensor", op::translate_pow},
Expand Down
17 changes: 2 additions & 15 deletions src/frontends/pytorch/src/transforms/listconstruct_replacer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@

#include "openvino/core/rt_info.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/op/abs.hpp"
#include "openvino/op/adaptive_avg_pool.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
Expand All @@ -17,11 +15,9 @@
#include "openvino/op/multiply.hpp"
#include "openvino/op/random_uniform.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/roll.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/tile.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/util/framework_node.hpp"
#include "openvino/op/variadic_split.hpp"
#include "openvino/pass/pattern/matcher.hpp"
Expand All @@ -47,8 +43,6 @@ ListConstructReplacer::ListConstructReplacer() {
const auto& select_op = pattern::wrap_type<v1::Select>({pattern::any_input(), pattern::any_input(), list});
// replace list construct for aten::repeat(tensor, prim::ListConstruct(shapes)))
const auto& tile_op = pattern::wrap_type<v0::Tile>({pattern::any_input(), list});
// replace aten::permute(tensor, prim::ListConstruct)
const auto& transpose_op = pattern::wrap_type<v1::Transpose>({pattern::any_input(), list});
// aten::split_with_sizes case
const auto& vsplit_op = pattern::wrap_type<v1::VariadicSplit>({pattern::any_input(), pattern::any_input(), list});
// aten::upsample... case inside the body when body was removed
Expand All @@ -58,15 +52,8 @@ ListConstructReplacer::ListConstructReplacer() {
pattern::wrap_type<v11::Interpolate>({pattern::any_input(), interpolate_mul_op, pattern::any_input()});
// aten::randint case
const auto& rand_op = pattern::wrap_type<v8::RandomUniform>({list, pattern::any_input(), pattern::any_input()});
const auto& lc_pattern = std::make_shared<pattern::op::Or>(OutputVector{broadcast_op,
shape_of_op,
equal_op,
select_op,
tile_op,
transpose_op,
vsplit_op,
interpolate_op,
rand_op});
const auto& lc_pattern = std::make_shared<pattern::op::Or>(
OutputVector{broadcast_op, shape_of_op, equal_op, select_op, tile_op, vsplit_op, interpolate_op, rand_op});

ov::matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto& pattern_map = m.get_pattern_value_map();
Expand Down
5 changes: 3 additions & 2 deletions tests/layer_tests/pytorch_tests/test_permute.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,14 @@ def forward(self, x):

return aten_permute(order), ref_net, "aten::permute"

@pytest.mark.parametrize("order", [[0, 2, 3, 1], [0, 3, 1, 2]])
@pytest.mark.parametrize("order", [[0, 2, 3, 1], [0, 3, 1, 2], [0, -1, 1, -2]])
@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.precommit_torch_export
def test_permute(self, order, ie_device, precision, ir_version):
self._test(*self.create_model(order), ie_device, precision, ir_version)


class TestPermuteList(PytorchLayerTest):
def _prepare_input(self, permute_shape):
import numpy as np
Expand All @@ -55,6 +56,6 @@ def forward(self, x, y):
@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.precommit_torch_export
def test_permute(self, order, ie_device, precision, ir_version):
def test_permute_list(self, order, ie_device, precision, ir_version):
self._test(*self.create_model(), ie_device, precision, ir_version,
kwargs_to_prepare_input={"permute_shape": order}, dynamic_shapes=ie_device != "GPU")

0 comments on commit 3503fe6

Please sign in to comment.