Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat]: index_fill_ frontend pytorch op #27420

Merged
merged 11 commits into from
Nov 26, 2024
69 changes: 69 additions & 0 deletions src/frontends/pytorch/src/op/index_fill_.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <memory>

#include "openvino/core/node.hpp"
#include "openvino/core/node_output.hpp"
#include "openvino/core/type/element_type.hpp"
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/scatter_elements_update.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "utils.hpp"

using namespace ov::op;

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_index_fill_(const NodeContext& context) {
// aten::index_fill_(self, dim, index, value) --> Tensor
num_inputs_check(context, 4, 4);
auto input = context.get_input(0);
auto dim = context.get_input(1);
auto index = context.get_input(2);
auto value = context.get_input(3);

auto const_1_vec = v0::Constant::create(element::i32, Shape{1}, {1});

auto tensor_rank = std::get<1>(get_shape_rank(context, input, false));
auto tensor_rank_correct_type = context.mark_node(std::make_shared<v1::ConvertLike>(tensor_rank, dim));
auto dim_vec = normalize_axis(context, dim, tensor_rank_correct_type);

// scalar to vec
auto value_vec = context.mark_node(std::make_shared<v1::Reshape>(value, const_1_vec, false));

auto input_shape = std::get<0>(get_shape_rank(context, input, false));

auto index_shape = std::get<0>(get_shape_rank(context, index, false));
auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
auto index_len = context.mark_node(std::make_shared<v8::Slice>(index_shape, const_0, const_1, const_1));

// [A, B, ..., T, ..., K] --> [A, B, ..., len(index), ..., K]
auto target_shape = std::make_shared<v12::ScatterElementsUpdate>(input_shape,
dim_vec,
index_len,
v0::Constant::create(element::i32, Shape{}, {0}));

// broadcast && index fill
auto broadcasted_value = context.mark_node(std::make_shared<v1::Broadcast>(value_vec, target_shape, dim_vec));
auto broadcasted_index = context.mark_node(std::make_shared<v1::Broadcast>(index, target_shape, dim_vec));
auto result = context.mark_node(
std::make_shared<v12::ScatterElementsUpdate>(input, broadcasted_index, broadcasted_value, dim));

return {result};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ OP_CONVERTER(translate_im2col);
OP_CONVERTER(translate_index);
OP_CONVERTER(translate_index_add);
OP_CONVERTER(translate_index_copy_);
OP_CONVERTER(translate_index_fill_);
OP_CONVERTER(translate_index_put_);
OP_CONVERTER(translate_index_select);
OP_CONVERTER(translate_instance_norm);
Expand Down Expand Up @@ -496,6 +497,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
// aten::imag - Supported in limited set of patterns
// aten::index - Supported in limited set of patterns
{"aten::index_copy_", op::inplace_op<op::translate_index_copy_>},
{"aten::index_fill_", op::inplace_op<op::translate_index_fill_>},
{"aten::index_put_", op::inplace_op<op::translate_index_put_>},
{"aten::index_add", op::translate_index_add},
{"aten::index_select", op::translate_index_select},
Expand Down
83 changes: 83 additions & 0 deletions tests/layer_tests/pytorch_tests/test_index_fill_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


import numpy as np
import pytest
import torch

from pytorch_layer_test_class import PytorchLayerTest


class TestIndexFill(PytorchLayerTest):
def _prepare_input(self):
return (self.input_tensor,)

def create_model(self, dim, index, values):
class aten_index_fill_(torch.nn.Module):
def __init__(self, dim, index, values):
super().__init__()
self.dim = dim
self.index = index
self.values = values

def forward(self, input_tensor):
input_tensor.index_fill_(self.dim, self.index, self.values)
return input_tensor

ref_net = None

return aten_index_fill_(dim, index, values), ref_net, "aten::index_fill_"

@pytest.mark.parametrize(
"input_data",
(
{
"input_shape": [10],
"dim": 0,
"input_value": 5.6,
"index": [5, 6, 7]
},
{
"input_shape": [3, 3],
"dim": 0,
"input_value": 10.1,
"index": [1, 0]
},
{
"input_shape": [4, 3, 5],
"dim": 1,
"input_value": 1234.5,
"index": [2, 0]
},
{
"input_shape": [5, 6, 7, 8],
"dim": -2,
"input_value": 0.1234,
"index": [6, 4, 2, 0]
},
{
"input_shape": [5, 6, 7, 8],
"dim": -3,
"input_value": -4321234.5678765,
"index": [5, 4, 3, 1]
},
{
"input_shape": [5, 6, 7, 8],
"dim": 3,
"input_value": -1234.54321,
"index": [6, 4, 7, 2, 1]
},
),
)
@pytest.mark.nightly
@pytest.mark.precommit
def test_index_fill_single_index(self, ie_device, precision, ir_version, input_data):
self.input_tensor = np.random.randn(*input_data["input_shape"]).astype(np.float32)
values = torch.tensor(np.float32(input_data["input_value"]))
dim = input_data["dim"]
shape = self.input_tensor.shape
max_idx = shape[dim]
n_select = np.random.randint(1, max_idx + 1)
index = torch.from_numpy(np.random.choice(np.arange(0, max_idx), n_select, replace=False)).to(torch.long)
self._test(*self.create_model(dim, index, values), ie_device, precision, ir_version)
Loading