From ea14d296686adc60b8d22a63f7dbcc681a081d64 Mon Sep 17 00:00:00 2001 From: Maxim Vafin Date: Fri, 11 Oct 2024 21:42:35 +0200 Subject: [PATCH] [PT FE] Support aten::unsafe_chunk op (#26931) ### Details: - *Support `aten::unsafe_chunk` op* ### Tickets: - *ticket-id* --- src/frontends/pytorch/src/op/getitem.cpp | 7 +- src/frontends/pytorch/src/op_table.cpp | 1 + .../src/transforms/aten_cat_replacer.cpp | 6 +- .../src/transforms/aten_getitem_replacer.cpp | 2 +- .../transforms/aten_index_put_replacer.cpp | 7 +- .../index_loop_getitem_replacer.cpp | 2 +- .../transforms/prim_list_unpack_replacer.cpp | 2 +- src/frontends/pytorch/src/utils.cpp | 15 +++ src/frontends/pytorch/src/utils.hpp | 2 + tests/layer_tests/pytorch_tests/test_chunk.py | 104 +++++++++++------- 10 files changed, 93 insertions(+), 55 deletions(-) diff --git a/src/frontends/pytorch/src/op/getitem.cpp b/src/frontends/pytorch/src/op/getitem.cpp index c28df052c6d362..c6f0d52c924895 100644 --- a/src/frontends/pytorch/src/op/getitem.cpp +++ b/src/frontends/pytorch/src/op/getitem.cpp @@ -23,10 +23,9 @@ OutputVector translate_getitem(const NodeContext& context) { PYTORCH_OP_CONVERSION_CHECK(!idx_type.is(), "String index in aten::__getitem__ means dict input, this is not supported."); if (ov::as_type_ptr(input.get_node_shared_ptr())) { - PYTORCH_OP_CONVERSION_CHECK(!cast_fw_node(input.get_node_shared_ptr(), "aten::split"), - "special case for aten::__getitem__"); - PYTORCH_OP_CONVERSION_CHECK(!cast_fw_node(input.get_node_shared_ptr(), "aten::chunk"), - "special case for aten::__getitem__"); + PYTORCH_OP_CONVERSION_CHECK( + !cast_fw_node(input.get_node_shared_ptr(), {"aten::split", "aten::chunk", "aten::unsafe_chunk"}), + "special case for aten::__getitem__"); const auto&& list_elems = get_list_as_outputs(input); auto getitem_idx = context.const_input(1); if (getitem_idx < 0) { diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 141e5b02ad8d25..b68c182e17ee2a 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -663,6 +663,7 @@ const std::unordered_map get_supported_ops_ts() { // aten::unbind - Supported in limited set of patterns {"aten::unflatten", op::translate_unflatten}, {"aten::unfold", op::translate_unfold}, + // aten::unsafe_chunk - Supported in limited set of patterns {"aten::unsqueeze", op::quantizable_op>}, {"aten::upsample_bicubic2d", op::translate_upsample_bicubic2d}, {"aten::upsample_bilinear2d", op::translate_upsample_bilinear2d}, diff --git a/src/frontends/pytorch/src/transforms/aten_cat_replacer.cpp b/src/frontends/pytorch/src/transforms/aten_cat_replacer.cpp index 1f31c75e6ae6c8..e42184ab6fad7f 100644 --- a/src/frontends/pytorch/src/transforms/aten_cat_replacer.cpp +++ b/src/frontends/pytorch/src/transforms/aten_cat_replacer.cpp @@ -37,11 +37,7 @@ AtenCatToConcat::AtenCatToConcat() { auto aten_cat = ov::pass::pattern::wrap_type(); ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) { - auto cat = cast_fw_node(m.get_match_root(), "aten::cat"); - if (!cat) - cat = cast_fw_node(m.get_match_root(), "aten::concat"); - if (!cat) - cat = cast_fw_node(m.get_match_root(), "quantized::cat"); + auto cat = cast_fw_node(m.get_match_root(), {"aten::cat", "aten::concat", "quantized::cat"}); if (!cat) return false; diff --git a/src/frontends/pytorch/src/transforms/aten_getitem_replacer.cpp b/src/frontends/pytorch/src/transforms/aten_getitem_replacer.cpp index 25be27e9939204..15852237c29f7f 100644 --- a/src/frontends/pytorch/src/transforms/aten_getitem_replacer.cpp +++ b/src/frontends/pytorch/src/transforms/aten_getitem_replacer.cpp @@ -124,7 +124,7 @@ AtenGetItemReplacer::AtenGetItemReplacer() { auto gather = rg.make(input_concat, getitem_idx, zero); replace_node(getitem, gather); } - } else if (auto chunk = cast_fw_node(input_node, "aten::chunk")) { + } else if (auto chunk = cast_fw_node(input_node, {"aten::chunk", "aten::unsafe_chunk"})) { auto input_tensor = chunk->get_input_source_output(0); auto chunks_i32 = chunk->get_input_source_output(1); auto dim_i32 = chunk->get_input_source_output(2); diff --git a/src/frontends/pytorch/src/transforms/aten_index_put_replacer.cpp b/src/frontends/pytorch/src/transforms/aten_index_put_replacer.cpp index 587ff587bd333e..3351644ee9f196 100644 --- a/src/frontends/pytorch/src/transforms/aten_index_put_replacer.cpp +++ b/src/frontends/pytorch/src/transforms/aten_index_put_replacer.cpp @@ -48,12 +48,9 @@ AtenIndexPutReplacer::AtenIndexPutReplacer() { auto index_op = ov::pass::pattern::wrap_type(); ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) { - auto index_op = cast_fw_node(m.get_match_root(), "aten::index_put_"); + auto index_op = cast_fw_node(m.get_match_root(), {"aten::index_put_", "aten.index_put.default"}); if (!index_op) { - index_op = cast_fw_node(m.get_match_root(), "aten.index_put.default"); - if (!index_op) { - return false; - } + return false; } NodeVector rt_copy_from; ov::pass::NodeRegistry rg; diff --git a/src/frontends/pytorch/src/transforms/index_loop_getitem_replacer.cpp b/src/frontends/pytorch/src/transforms/index_loop_getitem_replacer.cpp index 4aae4d6e2a35dc..41d48c06f06332 100644 --- a/src/frontends/pytorch/src/transforms/index_loop_getitem_replacer.cpp +++ b/src/frontends/pytorch/src/transforms/index_loop_getitem_replacer.cpp @@ -52,7 +52,7 @@ IndexLoopGetitemReplacer::IndexLoopGetitemReplacer() { size_t chunk_idx = 0; auto loop_inputs = loop_op->input_values(); for (size_t i = 1; i < loop_inputs.size(); i++) { - if (cast_fw_node(loop_inputs.at(i).get_node_shared_ptr(), "aten::chunk")) { + if (cast_fw_node(loop_inputs.at(i).get_node_shared_ptr(), {"aten::chunk", "aten::unsafe_chunk"})) { chunk_op = loop_inputs.at(i).get_node_shared_ptr(); chunk_idx = i; break; diff --git a/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp b/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp index 5560c9ff225e9d..67e2b5b37ecaac 100644 --- a/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp +++ b/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp @@ -88,7 +88,7 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() { replace_node(list_unpack, split); return true; - } else if (auto chunk = cast_fw_node(input_node, "aten::chunk")) { + } else if (auto chunk = cast_fw_node(input_node, {"aten::chunk", "aten::unsafe_chunk"})) { if (list_unpack->get_output_size() == 1) { list_unpack->output(0).replace(input_node->input_value(0)); return true; diff --git a/src/frontends/pytorch/src/utils.cpp b/src/frontends/pytorch/src/utils.cpp index 11e62baf2b606b..852de6e90fa25b 100644 --- a/src/frontends/pytorch/src/utils.cpp +++ b/src/frontends/pytorch/src/utils.cpp @@ -378,6 +378,21 @@ std::shared_ptr cast_fw_node(std::shared_ptr return fw_node; } +std::shared_ptr cast_fw_node(std::shared_ptr node, + std::initializer_list types) { + auto fw_node = std::dynamic_pointer_cast(node); + if (!fw_node) { + return nullptr; + } + const auto& attrs = fw_node->get_attrs(); + for (auto type : types) { + if (attrs.find(PtFrameworkNode::op_type_key) != attrs.end() && attrs.at(PtFrameworkNode::op_type_key) == type) { + return fw_node; + } + } + return nullptr; +} + std::shared_ptr make_list_construct(const ov::OutputVector& inputs) { auto list_construct = std::make_shared(inputs, inputs.size()); ov::op::util::FrameworkNodeAttrs attrs; diff --git a/src/frontends/pytorch/src/utils.hpp b/src/frontends/pytorch/src/utils.hpp index 434cc109d022aa..f4104a83ae3252 100644 --- a/src/frontends/pytorch/src/utils.hpp +++ b/src/frontends/pytorch/src/utils.hpp @@ -80,6 +80,8 @@ OutputVector make_framework_node_ignore_bodies(const NodeContext& context, const OutputVector make_framework_node(const NodeContext& context, const std::string& exception); std::shared_ptr cast_fw_node(std::shared_ptr node, const std::string& type); +std::shared_ptr cast_fw_node(std::shared_ptr node, + std::initializer_list types); std::shared_ptr make_list_construct(const ov::OutputVector& inputs); diff --git a/tests/layer_tests/pytorch_tests/test_chunk.py b/tests/layer_tests/pytorch_tests/test_chunk.py index 7b8cc3f890edb4..561d6cc6e84b4a 100644 --- a/tests/layer_tests/pytorch_tests/test_chunk.py +++ b/tests/layer_tests/pytorch_tests/test_chunk.py @@ -7,54 +7,71 @@ from pytorch_layer_test_class import PytorchLayerTest + class aten_chunk_2(torch.nn.Module): - def __init__(self, dim) -> None: + def __init__(self, dim, unsafe=False) -> None: torch.nn.Module.__init__(self) self.dim = dim + self.chunk_op = torch.chunk + if unsafe: + self.chunk_op = torch._VF.unsafe_chunk def forward(self, input_tensor): - a,b = torch.chunk(input_tensor, - chunks = 2, - dim = self.dim - ) - return a,b + a, b = self.chunk_op(input_tensor, + chunks=2, + dim=self.dim + ) + return a, b + class aten_chunk_3(torch.nn.Module): - def __init__(self, dim) -> None: + def __init__(self, dim, unsafe=False) -> None: torch.nn.Module.__init__(self) self.dim = dim + self.chunk_op = torch.chunk + if unsafe: + self.chunk_op = torch._VF.unsafe_chunk def forward(self, input_tensor): - a,b,c = torch.chunk(input_tensor, - chunks = 3, - dim = self.dim - ) - return a,b,c + a, b, c = self.chunk_op(input_tensor, + chunks=3, + dim=self.dim + ) + return a, b, c + class aten_chunk_4(torch.nn.Module): - def __init__(self, dim) -> None: + def __init__(self, dim, unsafe=False) -> None: torch.nn.Module.__init__(self) self.dim = dim + self.chunk_op = torch.chunk + if unsafe: + self.chunk_op = torch._VF.unsafe_chunk def forward(self, input_tensor): - a,b,c,d = torch.chunk(input_tensor, - chunks = 4, - dim = self.dim - ) - return a,b,c,d + a, b, c, d = self.chunk_op(input_tensor, + chunks=4, + dim=self.dim + ) + return a, b, c, d + class aten_chunk_getitem(torch.nn.Module): - def __init__(self, chunks, dim, idx) -> None: + def __init__(self, chunks, dim, idx, unsafe=False) -> None: torch.nn.Module.__init__(self) self.chunks = chunks self.dim = dim self.idx = idx + self.chunk_op = torch.chunk + if unsafe: + self.chunk_op = torch._VF.unsafe_chunk def forward(self, input_tensor): - return torch.chunk(input_tensor, - chunks = self.chunks, - dim = self.dim - )[self.idx] + return self.chunk_op(input_tensor, + chunks=self.chunks, + dim=self.dim + )[self.idx] + class TestChunk(PytorchLayerTest): def _prepare_input(self): @@ -68,23 +85,24 @@ def _prepare_input(self): ]) @pytest.mark.parametrize("chunks", [ # Does not work for 1 - no list_unpack present in the graph - # 1, + # 1, 2, 3, 4 ]) + @pytest.mark.parametrize("unsafe", [True, False]) @pytest.mark.nightly @pytest.mark.precommit - def test_chunk(self, input_shape, chunks, ie_device, precision, ir_version): + def test_chunk(self, input_shape, chunks, unsafe, ie_device, precision, ir_version): self.input_shape = input_shape - + for dim, dim_shape in enumerate(input_shape): chunk_size = dim_shape // chunks chunk_size += 1 if dim_shape % chunks > 0 else 0 output_chunks = dim_shape // chunk_size output_chunks += 1 if dim_shape % chunk_size > 0 else 0 - + if output_chunks == 2: cls = aten_chunk_2 elif output_chunks == 3: @@ -92,9 +110,11 @@ def test_chunk(self, input_shape, chunks, ie_device, precision, ir_version): elif output_chunks == 4: cls = aten_chunk_4 - self._test(cls(dim), None, "aten::chunk", - ie_device, precision, ir_version, dynamic_shapes = False, freeze_model=True, trace_model=True) - + self._test(cls(dim, unsafe), None, + "aten::unsafe_chunk" if unsafe else "aten::chunk", + ie_device, precision, ir_version, dynamic_shapes=False, + freeze_model=True, trace_model=True) + @pytest.mark.parametrize("input_shape", [ (4, 4), (10, 13, 11), @@ -105,9 +125,10 @@ def test_chunk(self, input_shape, chunks, ie_device, precision, ir_version): 3, 4 ]) + @pytest.mark.parametrize("unsafe", [True, False]) @pytest.mark.nightly @pytest.mark.precommit - def test_chunk_getitem(self, input_shape, chunks, ie_device, precision, ir_version): + def test_chunk_getitem(self, input_shape, chunks, unsafe, ie_device, precision, ir_version): self.input_shape = input_shape for dim, dim_shape in enumerate(input_shape): @@ -118,18 +139,22 @@ def test_chunk_getitem(self, input_shape, chunks, ie_device, precision, ir_versi output_chunks += 1 if dim_shape % chunk_size > 0 else 0 for idx in [0, 1, output_chunks - 1]: - self._test(aten_chunk_getitem(chunks, dim, idx), None, "aten::chunk", - ie_device, precision, ir_version) + self._test(aten_chunk_getitem(chunks, dim, idx, unsafe), None, + "aten::unsafe_chunk" if unsafe else "aten::chunk", + ie_device, precision, ir_version) class aten_chunk_loop_getitem(torch.nn.Module): - def __init__(self, num_chunks) -> None: + def __init__(self, num_chunks, unsafe=False) -> None: torch.nn.Module.__init__(self) self.num_chunks = num_chunks + self.chunk_op = torch.chunk + if unsafe: + self.chunk_op = torch._VF.unsafe_chunk def forward(self, input_tensor): - chunks = torch.chunk(torch.arange( - input_tensor.shape[0]), self.num_chunks) + x = torch.arange(input_tensor.shape[0]) + chunks = self.chunk_op(x, self.num_chunks) for inds in chunks: input_tensor[inds] *= 10 @@ -151,10 +176,13 @@ def _prepare_input(self): 3, 4 ]) + @pytest.mark.parametrize("unsafe", [True, False]) @pytest.mark.nightly @pytest.mark.precommit - def test_chunk_loop_getitem(self, input_shape, chunks, ie_device, precision, ir_version): + def test_chunk_loop_getitem(self, input_shape, chunks, unsafe, ie_device, precision, ir_version): self.input_shape = input_shape - self._test(aten_chunk_loop_getitem(chunks), None, ["aten::chunk", "prim::Loop", "aten::__getitem__"], + chunk_op = "aten::unsafe_chunk" if unsafe else "aten::chunk" + self._test(aten_chunk_loop_getitem(chunks, unsafe), None, + [chunk_op, "prim::Loop", "aten::__getitem__"], ie_device, precision, ir_version)