Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Index fill pytorch #26488

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions src/frontends/pytorch/src/op/index_fill_.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <memory>

#include "openvino/core/node.hpp"
#include "openvino/core/node_output.hpp"
#include "openvino/core/type/element_type.hpp"
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/scatter_elements_update.hpp"
#include "openvino/op/shape_of.hpp"
#include "utils.hpp"

using namespace ov::op;

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_index_copy_(const NodeContext& context) {
muhd360 marked this conversation as resolved.
Show resolved Hide resolved
// aten::index_fill_(self, dim, index, value) → Tensor
num_inputs_check(context, 4, 4);
auto input = context.get_input(0);
auto dim = context.get_input(1);
auto index = context.get_input(2);
auto value = context.get_input(3);

auto const_1_vec = v0::Constant::create(element::i32, Shape{1}, {value});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

value is Node, you are trying to create a constant from it.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok will fix


Output<Node> tensor_rank = std::get<1>(get_shape_rank(context, input, true));
auto tensor_rank_correct_type = context.mark_node(std::make_shared<v1::ConvertLike>(tensor_rank, dim));
auto positive_dim = normalize_axis(context, dim, tensor_rank_correct_type);

// begin the computation
//indx_cpy(dim,idx_def,tensor_def)
//indx_fill(dim,idx_def,val)
auto tensor_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
auto dim_vec = context.mark_node(std::make_shared<v1::Reshape>(positive_dim, const_1_vec, false));
auto broadcasted_index = context.mark_node(std::make_shared<v1::Broadcast>(index, tensor_shape, dim_vec));
//reshaping steps

//so we need to create a tensor with the same shape as the input tensor
//(val tensor is broadcasted to the shape of the input tensor) then index_copy
//val tensor->tensor of shape of input tensor with all values v




auto result =
context.mark_node(std::make_shared<v12::ScatterElementsUpdate>(input, broadcasted_index, index, dim));
return {result};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
85 changes: 85 additions & 0 deletions tests/layer_tests/pytorch_tests/test_index_fill_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


import numpy as np
import pytest
import torch

from pytorch_layer_test_class import PytorchLayerTest


class TestIndexFill(PytorchLayerTest):
def _prepare_input(self):
return (self.input_tensor, self.values)

def create_model(self, dim, index):
class aten_index_fill_(torch.nn.Module):
def __init__(self, dim, index):
super().__init__()
self.dim = dim
self.index = index

def forward(self, input_tensor, values):
input_tensor.index_fill_(self.dim, self.index, values)
return input_tensor

ref_net = None

return aten_index_fill_(dim, index), ref_net, "aten::index_fill_"

@pytest.mark.parametrize(
"input_data",
(
{
"input_shape": [1],
"dim": 0,
"values_shape": [1],
"index": torch.tensor([0], dtype=torch.long)
},
{
"input_shape": [10],
"dim": 0,
"values_shape": [5],
"index": torch.tensor([2, 3, 6, 7, 1], dtype=torch.long)
},
{
"input_shape": [3, 3],
"dim": 0,
"values_shape": [2, 3],
"index": torch.tensor([2, 0], dtype=torch.long)
},
{
"input_shape": [4, 3, 5],
"dim": 1,
"values_shape": [4, 2, 5],
"index": torch.tensor([1, 0], dtype=torch.long)
},
{
"input_shape": [5, 6, 7, 8],
"dim": -2,
"values_shape": [5, 6, 4, 8],
"index": torch.tensor([5, 0, 6, 3], dtype=torch.long)
},
{
"input_shape": [5, 6, 7, 8],
"dim": -3,
"values_shape": [5, 3, 7, 8],
"index": torch.tensor([2, 0, 1], dtype=torch.long)
},
{
"input_shape": [5, 6, 7, 8],
"dim": 3,
"values_shape": [5, 6, 7, 5],
"index": torch.tensor([2, 6, 0, 4, 1], dtype=torch.long)
},
),
)
@pytest.mark.nightly
@pytest.mark.precommit
def test_index_copy_single_index(self, ie_device, precision, ir_version, input_data):
self.input_tensor = np.random.randn(*input_data["input_shape"]).astype(np.float32)
self.values = np.random.randn(*input_data["values_shape"]).astype(np.float32)
index = input_data["index"]
dim = input_data["dim"]
self._test(*self.create_model(dim, index), ie_device, precision, ir_version)
Loading