From 18ba7ba34b7de070834663ae0dfd53f51d1c0e34 Mon Sep 17 00:00:00 2001 From: muhd360 Date: Mon, 9 Sep 2024 14:58:27 +0530 Subject: [PATCH 1/9] Added index_fill for pytorch frontend and tests --- .../pytorch_tests/test_index_fill_.py | 85 +++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100644 tests/layer_tests/pytorch_tests/test_index_fill_.py diff --git a/tests/layer_tests/pytorch_tests/test_index_fill_.py b/tests/layer_tests/pytorch_tests/test_index_fill_.py new file mode 100644 index 00000000000000..f85f0306ea5984 --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_index_fill_.py @@ -0,0 +1,85 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import numpy as np +import pytest +import torch + +from pytorch_layer_test_class import PytorchLayerTest + + +class TestIndexFill(PytorchLayerTest): + def _prepare_input(self): + return (self.input_tensor, self.values) + + def create_model(self, dim, index): + class aten_index_fill_(torch.nn.Module): + def __init__(self, dim, index): + super().__init__() + self.dim = dim + self.index = index + + def forward(self, input_tensor, values): + input_tensor.index_fill_(self.dim, self.index, values) + return input_tensor + + ref_net = None + + return aten_index_fill_(dim, index), ref_net, "aten::index_fill_" + + @pytest.mark.parametrize( + "input_data", + ( + { + "input_shape": [1], + "dim": 0, + "values_shape": [1], + "index": torch.tensor([0], dtype=torch.long) + }, + { + "input_shape": [10], + "dim": 0, + "values_shape": [5], + "index": torch.tensor([2, 3, 6, 7, 1], dtype=torch.long) + }, + { + "input_shape": [3, 3], + "dim": 0, + "values_shape": [2, 3], + "index": torch.tensor([2, 0], dtype=torch.long) + }, + { + "input_shape": [4, 3, 5], + "dim": 1, + "values_shape": [4, 2, 5], + "index": torch.tensor([1, 0], dtype=torch.long) + }, + { + "input_shape": [5, 6, 7, 8], + "dim": -2, + "values_shape": [5, 6, 4, 8], + "index": torch.tensor([5, 0, 6, 3], dtype=torch.long) + }, + { + "input_shape": [5, 6, 7, 8], + "dim": -3, + "values_shape": [5, 3, 7, 8], + "index": torch.tensor([2, 0, 1], dtype=torch.long) + }, + { + "input_shape": [5, 6, 7, 8], + "dim": 3, + "values_shape": [5, 6, 7, 5], + "index": torch.tensor([2, 6, 0, 4, 1], dtype=torch.long) + }, + ), + ) + @pytest.mark.nightly + @pytest.mark.precommit + def test_index_copy_single_index(self, ie_device, precision, ir_version, input_data): + self.input_tensor = np.random.randn(*input_data["input_shape"]).astype(np.float32) + self.values = np.random.randn(*input_data["values_shape"]).astype(np.float32) + index = input_data["index"] + dim = input_data["dim"] + self._test(*self.create_model(dim, index), ie_device, precision, ir_version) From bf2e3f600117014f965946ea18287afe1e3302e4 Mon Sep 17 00:00:00 2001 From: muhd360 Date: Mon, 9 Sep 2024 15:17:51 +0530 Subject: [PATCH 2/9] index_fill method --- src/frontends/pytorch/src/op/index_fill_.cpp | 63 ++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 src/frontends/pytorch/src/op/index_fill_.cpp diff --git a/src/frontends/pytorch/src/op/index_fill_.cpp b/src/frontends/pytorch/src/op/index_fill_.cpp new file mode 100644 index 00000000000000..bb1788929f7f2f --- /dev/null +++ b/src/frontends/pytorch/src/op/index_fill_.cpp @@ -0,0 +1,63 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "openvino/core/node.hpp" +#include "openvino/core/node_output.hpp" +#include "openvino/core/type/element_type.hpp" +#include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/broadcast.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/convert_like.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/scatter_elements_update.hpp" +#include "openvino/op/shape_of.hpp" +#include "utils.hpp" + +using namespace ov::op; + +namespace ov { +namespace frontend { +namespace pytorch { +namespace op { + +OutputVector translate_index_copy_(const NodeContext& context) { + // aten::index_fill_(self, dim, index, value) → Tensor + num_inputs_check(context, 4, 4); + auto input = context.get_input(0); + auto dim = context.get_input(1); + auto index = context.get_input(2); + auto value = context.get_input(3); + + auto const_1_vec = v0::Constant::create(element::i32, Shape{1}, {value}); + + Output tensor_rank = std::get<1>(get_shape_rank(context, input, true)); + auto tensor_rank_correct_type = context.mark_node(std::make_shared(tensor_rank, dim)); + auto positive_dim = normalize_axis(context, dim, tensor_rank_correct_type); + + // begin the computation + //indx_cpy(dim,idx_def,tensor_def) + //indx_fill(dim,idx_def,val) + auto tensor_shape = context.mark_node(std::make_shared(input, element::i32)); + auto dim_vec = context.mark_node(std::make_shared(positive_dim, const_1_vec, false)); + auto broadcasted_index = context.mark_node(std::make_shared(index, tensor_shape, dim_vec)); + //reshaping steps + + //so we need to create a tensor with the same shape as the input tensor + //(val tensor is broadcasted to the shape of the input tensor) then index_copy + //val tensor->tensor of shape of input tensor with all values v + + + + + auto result = + context.mark_node(std::make_shared(input, broadcasted_index, index, dim)); + return {result}; +}; + +} // namespace op +} // namespace pytorch +} // namespace frontend +} // namespace ov From 55c2ae39addec09a9e7624f0b9b7315e6157cb6b Mon Sep 17 00:00:00 2001 From: Muhammad sheikh Date: Mon, 9 Sep 2024 15:25:27 +0530 Subject: [PATCH 3/9] Update src/frontends/pytorch/src/op/index_fill_.cpp Co-authored-by: Maxim Vafin --- src/frontends/pytorch/src/op/index_fill_.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/frontends/pytorch/src/op/index_fill_.cpp b/src/frontends/pytorch/src/op/index_fill_.cpp index bb1788929f7f2f..981ad212a65e8e 100644 --- a/src/frontends/pytorch/src/op/index_fill_.cpp +++ b/src/frontends/pytorch/src/op/index_fill_.cpp @@ -23,7 +23,7 @@ namespace frontend { namespace pytorch { namespace op { -OutputVector translate_index_copy_(const NodeContext& context) { +OutputVector translate_index_fill_(const NodeContext& context) { // aten::index_fill_(self, dim, index, value) → Tensor num_inputs_check(context, 4, 4); auto input = context.get_input(0); From a9881052458da6ddfcf6db99bffcb55f42058cd3 Mon Sep 17 00:00:00 2001 From: muhd360 Date: Mon, 9 Sep 2024 15:35:10 +0530 Subject: [PATCH 4/9] updated the conversion function in the list of operators --- src/frontends/pytorch/src/op_table.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 1e4ecfc1e1367f..981e3eb2435961 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -483,6 +483,7 @@ const std::unordered_map get_supported_ops_ts() { // aten::imag - Supported in limited set of patterns // aten::index - Supported in limited set of patterns {"aten::index_copy_", op::inplace_op}, + {"aten::index_fill_", op::inplace_op}, {"aten::index_put_", op::inplace_op}, {"aten::index_add", op::translate_index_add}, {"aten::index_select", op::translate_index_select}, From ea8e6f58a191a6fcfc88dd2da21a9782e950f1ae Mon Sep 17 00:00:00 2001 From: muhd360 Date: Tue, 10 Sep 2024 08:32:48 +0530 Subject: [PATCH 5/9] op convertor --- src/frontends/pytorch/src/op_table.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 981e3eb2435961..1dffa6756f6e05 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -110,6 +110,7 @@ OP_CONVERTER(translate_im2col); OP_CONVERTER(translate_index); OP_CONVERTER(translate_index_add); OP_CONVERTER(translate_index_copy_); +OP_CONVERTER(translate_index_fill_); OP_CONVERTER(translate_index_put_); OP_CONVERTER(translate_index_select); OP_CONVERTER(translate_instance_norm); From c58b7766823735cc0ed7954b9e39fa1cfea72d13 Mon Sep 17 00:00:00 2001 From: muhd360 Date: Wed, 11 Sep 2024 23:56:17 +0530 Subject: [PATCH 6/9] clang-format --- src/frontends/pytorch/src/op/index_fill_.cpp | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/src/frontends/pytorch/src/op/index_fill_.cpp b/src/frontends/pytorch/src/op/index_fill_.cpp index 981ad212a65e8e..7b58a6cf695beb 100644 --- a/src/frontends/pytorch/src/op/index_fill_.cpp +++ b/src/frontends/pytorch/src/op/index_fill_.cpp @@ -24,7 +24,7 @@ namespace pytorch { namespace op { OutputVector translate_index_fill_(const NodeContext& context) { - // aten::index_fill_(self, dim, index, value) → Tensor + num_inputs_check(context, 4, 4); auto input = context.get_input(0); auto dim = context.get_input(1); @@ -37,18 +37,11 @@ OutputVector translate_index_fill_(const NodeContext& context) { auto tensor_rank_correct_type = context.mark_node(std::make_shared(tensor_rank, dim)); auto positive_dim = normalize_axis(context, dim, tensor_rank_correct_type); - // begin the computation - //indx_cpy(dim,idx_def,tensor_def) - //indx_fill(dim,idx_def,val) + auto tensor_shape = context.mark_node(std::make_shared(input, element::i32)); auto dim_vec = context.mark_node(std::make_shared(positive_dim, const_1_vec, false)); auto broadcasted_index = context.mark_node(std::make_shared(index, tensor_shape, dim_vec)); - //reshaping steps - - //so we need to create a tensor with the same shape as the input tensor - //(val tensor is broadcasted to the shape of the input tensor) then index_copy - //val tensor->tensor of shape of input tensor with all values v - + From 3a066218a408ae901c4ac6783f06d6fbc4c3a471 Mon Sep 17 00:00:00 2001 From: muhd360 Date: Sat, 14 Sep 2024 21:57:51 +0530 Subject: [PATCH 7/9] CODE STYLE --- src/frontends/pytorch/src/op/index_fill_.cpp | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/frontends/pytorch/src/op/index_fill_.cpp b/src/frontends/pytorch/src/op/index_fill_.cpp index 7b58a6cf695beb..3c81b390984293 100644 --- a/src/frontends/pytorch/src/op/index_fill_.cpp +++ b/src/frontends/pytorch/src/op/index_fill_.cpp @@ -24,7 +24,6 @@ namespace pytorch { namespace op { OutputVector translate_index_fill_(const NodeContext& context) { - num_inputs_check(context, 4, 4); auto input = context.get_input(0); auto dim = context.get_input(1); @@ -37,16 +36,11 @@ OutputVector translate_index_fill_(const NodeContext& context) { auto tensor_rank_correct_type = context.mark_node(std::make_shared(tensor_rank, dim)); auto positive_dim = normalize_axis(context, dim, tensor_rank_correct_type); - auto tensor_shape = context.mark_node(std::make_shared(input, element::i32)); auto dim_vec = context.mark_node(std::make_shared(positive_dim, const_1_vec, false)); auto broadcasted_index = context.mark_node(std::make_shared(index, tensor_shape, dim_vec)); - - - - auto result = - context.mark_node(std::make_shared(input, broadcasted_index, index, dim)); + auto result = context.mark_node(std::make_shared(input, broadcasted_index, index, dim)); return {result}; }; From 18771c943a56d67e69c22f9e6db3c3e40da93177 Mon Sep 17 00:00:00 2001 From: muhd360 Date: Fri, 20 Sep 2024 21:17:16 +0530 Subject: [PATCH 8/9] node --- src/frontends/pytorch/src/op/index_fill_.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/frontends/pytorch/src/op/index_fill_.cpp b/src/frontends/pytorch/src/op/index_fill_.cpp index 3c81b390984293..87eb2d7f1a8f60 100644 --- a/src/frontends/pytorch/src/op/index_fill_.cpp +++ b/src/frontends/pytorch/src/op/index_fill_.cpp @@ -39,8 +39,13 @@ OutputVector translate_index_fill_(const NodeContext& context) { auto tensor_shape = context.mark_node(std::make_shared(input, element::i32)); auto dim_vec = context.mark_node(std::make_shared(positive_dim, const_1_vec, false)); auto broadcasted_index = context.mark_node(std::make_shared(index, tensor_shape, dim_vec)); + // Assuming v is your value node and broadcasted_index is already created + auto index_shape = context.mark_node(std::make_shared(broadcasted_index, element::i32)); - auto result = context.mark_node(std::make_shared(input, broadcasted_index, index, dim)); + // Create a tensor filled with the value of v + auto filled_with_v = context.mark_node(std::make_shared(value, index_shape)); + + auto result = context.mark_node(std::make_shared(input, filled_with_v, index, dim)); return {result}; }; From 1f8dbf436f2f70653b2ca3a95b27b6b308f768ec Mon Sep 17 00:00:00 2001 From: muhd360 Date: Sat, 21 Sep 2024 12:06:53 +0530 Subject: [PATCH 9/9] format --- src/frontends/pytorch/src/op/index_fill_.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/frontends/pytorch/src/op/index_fill_.cpp b/src/frontends/pytorch/src/op/index_fill_.cpp index 87eb2d7f1a8f60..6e8aaa5384d97a 100644 --- a/src/frontends/pytorch/src/op/index_fill_.cpp +++ b/src/frontends/pytorch/src/op/index_fill_.cpp @@ -30,7 +30,7 @@ OutputVector translate_index_fill_(const NodeContext& context) { auto index = context.get_input(2); auto value = context.get_input(3); - auto const_1_vec = v0::Constant::create(element::i32, Shape{1}, {value}); + auto const_1_vec = v0::Constant::create(element::i32, Shape{1}, {0}); Output tensor_rank = std::get<1>(get_shape_rank(context, input, true)); auto tensor_rank_correct_type = context.mark_node(std::make_shared(tensor_rank, dim));