diff --git a/src/frontends/pytorch/src/op/index_fill_.cpp b/src/frontends/pytorch/src/op/index_fill_.cpp new file mode 100644 index 00000000000000..a24f3fa2f5b1c7 --- /dev/null +++ b/src/frontends/pytorch/src/op/index_fill_.cpp @@ -0,0 +1,69 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "openvino/core/node.hpp" +#include "openvino/core/node_output.hpp" +#include "openvino/core/type/element_type.hpp" +#include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/broadcast.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/convert_like.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/scatter_elements_update.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/op/slice.hpp" +#include "utils.hpp" + +using namespace ov::op; + +namespace ov { +namespace frontend { +namespace pytorch { +namespace op { + +OutputVector translate_index_fill_(const NodeContext& context) { + // aten::index_fill_(self, dim, index, value) --> Tensor + num_inputs_check(context, 4, 4); + auto input = context.get_input(0); + auto dim = context.get_input(1); + auto index = context.get_input(2); + auto value = context.get_input(3); + + auto const_1_vec = v0::Constant::create(element::i32, Shape{1}, {1}); + + auto tensor_rank = std::get<1>(get_shape_rank(context, input, false)); + auto tensor_rank_correct_type = context.mark_node(std::make_shared(tensor_rank, dim)); + auto dim_vec = normalize_axis(context, dim, tensor_rank_correct_type); + + // scalar to vec + auto value_vec = context.mark_node(std::make_shared(value, const_1_vec, false)); + + auto input_shape = std::get<0>(get_shape_rank(context, input, false)); + + auto index_shape = std::get<0>(get_shape_rank(context, index, false)); + auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0})); + auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1})); + auto index_len = context.mark_node(std::make_shared(index_shape, const_0, const_1, const_1)); + + // [A, B, ..., T, ..., K] --> [A, B, ..., len(index), ..., K] + auto target_shape = std::make_shared(input_shape, + dim_vec, + index_len, + v0::Constant::create(element::i32, Shape{}, {0})); + + // broadcast && index fill + auto broadcasted_value = context.mark_node(std::make_shared(value_vec, target_shape, dim_vec)); + auto broadcasted_index = context.mark_node(std::make_shared(index, target_shape, dim_vec)); + auto result = context.mark_node( + std::make_shared(input, broadcasted_index, broadcasted_value, dim)); + + return {result}; +}; + +} // namespace op +} // namespace pytorch +} // namespace frontend +} // namespace ov diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 0fd2d8e54006a4..7307833430411f 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -113,6 +113,7 @@ OP_CONVERTER(translate_im2col); OP_CONVERTER(translate_index); OP_CONVERTER(translate_index_add); OP_CONVERTER(translate_index_copy_); +OP_CONVERTER(translate_index_fill_); OP_CONVERTER(translate_index_put_); OP_CONVERTER(translate_index_select); OP_CONVERTER(translate_instance_norm); @@ -496,6 +497,7 @@ const std::unordered_map get_supported_ops_ts() { // aten::imag - Supported in limited set of patterns // aten::index - Supported in limited set of patterns {"aten::index_copy_", op::inplace_op}, + {"aten::index_fill_", op::inplace_op}, {"aten::index_put_", op::inplace_op}, {"aten::index_add", op::translate_index_add}, {"aten::index_select", op::translate_index_select}, diff --git a/tests/layer_tests/pytorch_tests/test_index_fill_.py b/tests/layer_tests/pytorch_tests/test_index_fill_.py new file mode 100644 index 00000000000000..878dda7ab3bd7e --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_index_fill_.py @@ -0,0 +1,83 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +import numpy as np +import pytest +import torch + +from pytorch_layer_test_class import PytorchLayerTest + + +class TestIndexFill(PytorchLayerTest): + def _prepare_input(self): + return (self.input_tensor,) + + def create_model(self, dim, index, values): + class aten_index_fill_(torch.nn.Module): + def __init__(self, dim, index, values): + super().__init__() + self.dim = dim + self.index = index + self.values = values + + def forward(self, input_tensor): + input_tensor.index_fill_(self.dim, self.index, self.values) + return input_tensor + + ref_net = None + + return aten_index_fill_(dim, index, values), ref_net, "aten::index_fill_" + + @pytest.mark.parametrize( + "input_data", + ( + { + "input_shape": [10], + "dim": 0, + "input_value": 5.6, + "index": [5, 6, 7] + }, + { + "input_shape": [3, 3], + "dim": 0, + "input_value": 10.1, + "index": [1, 0] + }, + { + "input_shape": [4, 3, 5], + "dim": 1, + "input_value": 1234.5, + "index": [2, 0] + }, + { + "input_shape": [5, 6, 7, 8], + "dim": -2, + "input_value": 0.1234, + "index": [6, 4, 2, 0] + }, + { + "input_shape": [5, 6, 7, 8], + "dim": -3, + "input_value": -4321234.5678765, + "index": [5, 4, 3, 1] + }, + { + "input_shape": [5, 6, 7, 8], + "dim": 3, + "input_value": -1234.54321, + "index": [6, 4, 7, 2, 1] + }, + ), + ) + @pytest.mark.nightly + @pytest.mark.precommit + def test_index_fill_single_index(self, ie_device, precision, ir_version, input_data): + self.input_tensor = np.random.randn(*input_data["input_shape"]).astype(np.float32) + values = torch.tensor(np.float32(input_data["input_value"])) + dim = input_data["dim"] + shape = self.input_tensor.shape + max_idx = shape[dim] + n_select = np.random.randint(1, max_idx + 1) + index = torch.from_numpy(np.random.choice(np.arange(0, max_idx), n_select, replace=False)).to(torch.long) + self._test(*self.create_model(dim, index, values), ie_device, precision, ir_version) \ No newline at end of file