Skip to content

Commit

Permalink
[PT FE] Support aten::unsafe_chunk op (#26931)
Browse files Browse the repository at this point in the history
### Details:
 - *Support `aten::unsafe_chunk` op*

### Tickets:
 - *ticket-id*
  • Loading branch information
mvafin authored Oct 11, 2024
1 parent e2b09ea commit ea14d29
Show file tree
Hide file tree
Showing 10 changed files with 93 additions and 55 deletions.
7 changes: 3 additions & 4 deletions src/frontends/pytorch/src/op/getitem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,9 @@ OutputVector translate_getitem(const NodeContext& context) {
PYTORCH_OP_CONVERSION_CHECK(!idx_type.is<type::Str>(),
"String index in aten::__getitem__ means dict input, this is not supported.");
if (ov::as_type_ptr<ov::op::util::FrameworkNode>(input.get_node_shared_ptr())) {
PYTORCH_OP_CONVERSION_CHECK(!cast_fw_node(input.get_node_shared_ptr(), "aten::split"),
"special case for aten::__getitem__");
PYTORCH_OP_CONVERSION_CHECK(!cast_fw_node(input.get_node_shared_ptr(), "aten::chunk"),
"special case for aten::__getitem__");
PYTORCH_OP_CONVERSION_CHECK(
!cast_fw_node(input.get_node_shared_ptr(), {"aten::split", "aten::chunk", "aten::unsafe_chunk"}),
"special case for aten::__getitem__");
const auto&& list_elems = get_list_as_outputs(input);
auto getitem_idx = context.const_input<int64_t>(1);
if (getitem_idx < 0) {
Expand Down
1 change: 1 addition & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
// aten::unbind - Supported in limited set of patterns
{"aten::unflatten", op::translate_unflatten},
{"aten::unfold", op::translate_unfold},
// aten::unsafe_chunk - Supported in limited set of patterns
{"aten::unsqueeze", op::quantizable_op<op::translate_1to1_match_2_inputs<opset10::Unsqueeze>>},
{"aten::upsample_bicubic2d", op::translate_upsample_bicubic2d},
{"aten::upsample_bilinear2d", op::translate_upsample_bilinear2d},
Expand Down
6 changes: 1 addition & 5 deletions src/frontends/pytorch/src/transforms/aten_cat_replacer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,7 @@ AtenCatToConcat::AtenCatToConcat() {
auto aten_cat = ov::pass::pattern::wrap_type<ov::op::util::FrameworkNode>();

ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) {
auto cat = cast_fw_node(m.get_match_root(), "aten::cat");
if (!cat)
cat = cast_fw_node(m.get_match_root(), "aten::concat");
if (!cat)
cat = cast_fw_node(m.get_match_root(), "quantized::cat");
auto cat = cast_fw_node(m.get_match_root(), {"aten::cat", "aten::concat", "quantized::cat"});
if (!cat)
return false;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ AtenGetItemReplacer::AtenGetItemReplacer() {
auto gather = rg.make<v8::Gather>(input_concat, getitem_idx, zero);
replace_node(getitem, gather);
}
} else if (auto chunk = cast_fw_node(input_node, "aten::chunk")) {
} else if (auto chunk = cast_fw_node(input_node, {"aten::chunk", "aten::unsafe_chunk"})) {
auto input_tensor = chunk->get_input_source_output(0);
auto chunks_i32 = chunk->get_input_source_output(1);
auto dim_i32 = chunk->get_input_source_output(2);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,9 @@ AtenIndexPutReplacer::AtenIndexPutReplacer() {
auto index_op = ov::pass::pattern::wrap_type<ov::op::util::FrameworkNode>();

ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) {
auto index_op = cast_fw_node(m.get_match_root(), "aten::index_put_");
auto index_op = cast_fw_node(m.get_match_root(), {"aten::index_put_", "aten.index_put.default"});
if (!index_op) {
index_op = cast_fw_node(m.get_match_root(), "aten.index_put.default");
if (!index_op) {
return false;
}
return false;
}
NodeVector rt_copy_from;
ov::pass::NodeRegistry rg;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ IndexLoopGetitemReplacer::IndexLoopGetitemReplacer() {
size_t chunk_idx = 0;
auto loop_inputs = loop_op->input_values();
for (size_t i = 1; i < loop_inputs.size(); i++) {
if (cast_fw_node(loop_inputs.at(i).get_node_shared_ptr(), "aten::chunk")) {
if (cast_fw_node(loop_inputs.at(i).get_node_shared_ptr(), {"aten::chunk", "aten::unsafe_chunk"})) {
chunk_op = loop_inputs.at(i).get_node_shared_ptr();
chunk_idx = i;
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
replace_node(list_unpack, split);

return true;
} else if (auto chunk = cast_fw_node(input_node, "aten::chunk")) {
} else if (auto chunk = cast_fw_node(input_node, {"aten::chunk", "aten::unsafe_chunk"})) {
if (list_unpack->get_output_size() == 1) {
list_unpack->output(0).replace(input_node->input_value(0));
return true;
Expand Down
15 changes: 15 additions & 0 deletions src/frontends/pytorch/src/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,21 @@ std::shared_ptr<ov::op::util::FrameworkNode> cast_fw_node(std::shared_ptr<Node>
return fw_node;
}

// Cast `node` to a FrameworkNode and return it only if its PtFrameworkNode
// op-type attribute matches one of `types`; otherwise return nullptr.
// Multi-type counterpart of the single-type cast_fw_node overload above.
std::shared_ptr<ov::op::util::FrameworkNode> cast_fw_node(std::shared_ptr<Node> node,
                                                          std::initializer_list<std::string> types) {
    auto fw_node = std::dynamic_pointer_cast<ov::op::util::FrameworkNode>(node);
    if (!fw_node) {
        return nullptr;
    }
    const auto& attrs = fw_node->get_attrs();
    // The attribute lookup does not depend on the loop variable: hoist it out of
    // the loop instead of repeating find() (and an implied at()) per candidate.
    const auto type_it = attrs.find(PtFrameworkNode::op_type_key);
    if (type_it == attrs.end()) {
        return nullptr;
    }
    for (const auto& type : types) {  // const ref: avoid copying each std::string
        if (type_it->second == type) {
            return fw_node;
        }
    }
    return nullptr;
}

std::shared_ptr<ov::Node> make_list_construct(const ov::OutputVector& inputs) {
auto list_construct = std::make_shared<ov::op::util::FrameworkNode>(inputs, inputs.size());
ov::op::util::FrameworkNodeAttrs attrs;
Expand Down
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ OutputVector make_framework_node_ignore_bodies(const NodeContext& context, const
OutputVector make_framework_node(const NodeContext& context, const std::string& exception);

std::shared_ptr<op::util::FrameworkNode> cast_fw_node(std::shared_ptr<Node> node, const std::string& type);
std::shared_ptr<op::util::FrameworkNode> cast_fw_node(std::shared_ptr<Node> node,
std::initializer_list<std::string> types);

std::shared_ptr<Node> make_list_construct(const ov::OutputVector& inputs);

Expand Down
104 changes: 66 additions & 38 deletions tests/layer_tests/pytorch_tests/test_chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,54 +7,71 @@

from pytorch_layer_test_class import PytorchLayerTest


class aten_chunk_2(torch.nn.Module):
    """Splits the input tensor into 2 chunks along `dim`.

    With unsafe=True the aten::unsafe_chunk variant (torch._VF.unsafe_chunk)
    is traced instead of the regular aten::chunk.
    """

    def __init__(self, dim, unsafe=False) -> None:
        torch.nn.Module.__init__(self)
        self.dim = dim
        self.chunk_op = torch.chunk
        if unsafe:
            self.chunk_op = torch._VF.unsafe_chunk

    def forward(self, input_tensor):
        # Tuple unpacking forces prim::ListUnpack into the traced graph.
        a, b = self.chunk_op(input_tensor,
                             chunks=2,
                             dim=self.dim
                             )
        return a, b


class aten_chunk_3(torch.nn.Module):
    """Splits the input tensor into 3 chunks along `dim`.

    With unsafe=True the aten::unsafe_chunk variant (torch._VF.unsafe_chunk)
    is traced instead of the regular aten::chunk.
    """

    def __init__(self, dim, unsafe=False) -> None:
        torch.nn.Module.__init__(self)
        self.dim = dim
        self.chunk_op = torch.chunk
        if unsafe:
            self.chunk_op = torch._VF.unsafe_chunk

    def forward(self, input_tensor):
        # Tuple unpacking forces prim::ListUnpack into the traced graph.
        a, b, c = self.chunk_op(input_tensor,
                                chunks=3,
                                dim=self.dim
                                )
        return a, b, c


class aten_chunk_4(torch.nn.Module):
    """Splits the input tensor into 4 chunks along `dim`.

    With unsafe=True the aten::unsafe_chunk variant (torch._VF.unsafe_chunk)
    is traced instead of the regular aten::chunk.
    """

    def __init__(self, dim, unsafe=False) -> None:
        torch.nn.Module.__init__(self)
        self.dim = dim
        self.chunk_op = torch.chunk
        if unsafe:
            self.chunk_op = torch._VF.unsafe_chunk

    def forward(self, input_tensor):
        # Tuple unpacking forces prim::ListUnpack into the traced graph.
        a, b, c, d = self.chunk_op(input_tensor,
                                   chunks=4,
                                   dim=self.dim
                                   )
        return a, b, c, d


class aten_chunk_getitem(torch.nn.Module):
    """Chunks the input along `dim` and returns only the chunk at `idx`,
    producing an aten::__getitem__ on the chunk-list in the traced graph.

    With unsafe=True the aten::unsafe_chunk variant (torch._VF.unsafe_chunk)
    is traced instead of the regular aten::chunk.
    """

    def __init__(self, chunks, dim, idx, unsafe=False) -> None:
        torch.nn.Module.__init__(self)
        self.chunks = chunks
        self.dim = dim
        self.idx = idx
        self.chunk_op = torch.chunk
        if unsafe:
            self.chunk_op = torch._VF.unsafe_chunk

    def forward(self, input_tensor):
        return self.chunk_op(input_tensor,
                             chunks=self.chunks,
                             dim=self.dim
                             )[self.idx]


class TestChunk(PytorchLayerTest):
def _prepare_input(self):
Expand All @@ -68,33 +85,36 @@ def _prepare_input(self):
])
@pytest.mark.parametrize("chunks", [
# Does not work for 1 - no list_unpack present in the graph
# 1,
# 1,
2,
3,
4
])
@pytest.mark.parametrize("unsafe", [True, False])
@pytest.mark.nightly
@pytest.mark.precommit
def test_chunk(self, input_shape, chunks, ie_device, precision, ir_version):
def test_chunk(self, input_shape, chunks, unsafe, ie_device, precision, ir_version):
self.input_shape = input_shape

for dim, dim_shape in enumerate(input_shape):
chunk_size = dim_shape // chunks
chunk_size += 1 if dim_shape % chunks > 0 else 0

output_chunks = dim_shape // chunk_size
output_chunks += 1 if dim_shape % chunk_size > 0 else 0

if output_chunks == 2:
cls = aten_chunk_2
elif output_chunks == 3:
cls = aten_chunk_3
elif output_chunks == 4:
cls = aten_chunk_4

self._test(cls(dim), None, "aten::chunk",
ie_device, precision, ir_version, dynamic_shapes = False, freeze_model=True, trace_model=True)

self._test(cls(dim, unsafe), None,
"aten::unsafe_chunk" if unsafe else "aten::chunk",
ie_device, precision, ir_version, dynamic_shapes=False,
freeze_model=True, trace_model=True)

@pytest.mark.parametrize("input_shape", [
(4, 4),
(10, 13, 11),
Expand All @@ -105,9 +125,10 @@ def test_chunk(self, input_shape, chunks, ie_device, precision, ir_version):
3,
4
])
@pytest.mark.parametrize("unsafe", [True, False])
@pytest.mark.nightly
@pytest.mark.precommit
def test_chunk_getitem(self, input_shape, chunks, ie_device, precision, ir_version):
def test_chunk_getitem(self, input_shape, chunks, unsafe, ie_device, precision, ir_version):
self.input_shape = input_shape
for dim, dim_shape in enumerate(input_shape):

Expand All @@ -118,18 +139,22 @@ def test_chunk_getitem(self, input_shape, chunks, ie_device, precision, ir_versi
output_chunks += 1 if dim_shape % chunk_size > 0 else 0

for idx in [0, 1, output_chunks - 1]:
self._test(aten_chunk_getitem(chunks, dim, idx), None, "aten::chunk",
ie_device, precision, ir_version)
self._test(aten_chunk_getitem(chunks, dim, idx, unsafe), None,
"aten::unsafe_chunk" if unsafe else "aten::chunk",
ie_device, precision, ir_version)


class aten_chunk_loop_getitem(torch.nn.Module):
def __init__(self, num_chunks) -> None:
def __init__(self, num_chunks, unsafe=False) -> None:
torch.nn.Module.__init__(self)
self.num_chunks = num_chunks
self.chunk_op = torch.chunk
if unsafe:
self.chunk_op = torch._VF.unsafe_chunk

def forward(self, input_tensor):
chunks = torch.chunk(torch.arange(
input_tensor.shape[0]), self.num_chunks)
x = torch.arange(input_tensor.shape[0])
chunks = self.chunk_op(x, self.num_chunks)

for inds in chunks:
input_tensor[inds] *= 10
Expand All @@ -151,10 +176,13 @@ def _prepare_input(self):
3,
4
])
@pytest.mark.parametrize("unsafe", [True, False])
@pytest.mark.nightly
@pytest.mark.precommit
def test_chunk_loop_getitem(self, input_shape, chunks, ie_device, precision, ir_version):
def test_chunk_loop_getitem(self, input_shape, chunks, unsafe, ie_device, precision, ir_version):
self.input_shape = input_shape

self._test(aten_chunk_loop_getitem(chunks), None, ["aten::chunk", "prim::Loop", "aten::__getitem__"],
chunk_op = "aten::unsafe_chunk" if unsafe else "aten::chunk"
self._test(aten_chunk_loop_getitem(chunks, unsafe), None,
[chunk_op, "prim::Loop", "aten::__getitem__"],
ie_device, precision, ir_version)

0 comments on commit ea14d29

Please sign in to comment.