-
Notifications
You must be signed in to change notification settings - Fork 2.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Index fill pytorch #26488
Closed
Closed
Index fill pytorch #26488
Changes from 5 commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
18ba7ba
Added index_fill for pytorch frontend and tests
muhd360 bf2e3f6
index_fill method
muhd360 55c2ae3
Update src/frontends/pytorch/src/op/index_fill_.cpp
muhd360 a988105
updated the conversion function in the list of operators
muhd360 a946f5c
Merge branch 'index_fill-pytorch' of https://github.com/muhd360/openv…
muhd360 ea8e6f5
op convertor
muhd360 ee1cf0b
Merge branch 'master' into index_fill-pytorch
muhd360 c58b776
clang-format
muhd360 ae157c6
Merge branch 'index_fill-pytorch' of https://github.com/muhd360/openv…
muhd360 957a523
Merge branch 'master' into index_fill-pytorch
muhd360 3a06621
CODE STYLE
muhd360 ec6de6f
Merge branch 'index_fill-pytorch' of https://github.com/muhd360/openv…
muhd360 a11712c
Merge branch 'master' into index_fill-pytorch
mlukasze 18771c9
node
muhd360 06a01d4
Merge branch 'index_fill-pytorch' of https://github.com/muhd360/openv…
muhd360 b61e5aa
Merge branch 'openvinotoolkit:master' into index_fill-pytorch
muhd360 1f8dbf4
format
muhd360 4acd3ae
Merge branch 'index_fill-pytorch' of https://github.com/muhd360/openv…
muhd360 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
// Copyright (C) 2018-2023 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include <memory> | ||
|
||
#include "openvino/core/node.hpp" | ||
#include "openvino/core/node_output.hpp" | ||
#include "openvino/core/type/element_type.hpp" | ||
#include "openvino/frontend/pytorch/node_context.hpp" | ||
#include "openvino/op/broadcast.hpp" | ||
#include "openvino/op/constant.hpp" | ||
#include "openvino/op/convert_like.hpp" | ||
#include "openvino/op/reshape.hpp" | ||
#include "openvino/op/scatter_elements_update.hpp" | ||
#include "openvino/op/shape_of.hpp" | ||
#include "utils.hpp" | ||
|
||
using namespace ov::op; | ||
|
||
namespace ov { | ||
namespace frontend { | ||
namespace pytorch { | ||
namespace op { | ||
|
||
OutputVector translate_index_fill_(const NodeContext& context) { | ||
// aten::index_fill_(self, dim, index, value) → Tensor | ||
num_inputs_check(context, 4, 4); | ||
auto input = context.get_input(0); | ||
auto dim = context.get_input(1); | ||
auto index = context.get_input(2); | ||
auto value = context.get_input(3); | ||
|
||
auto const_1_vec = v0::Constant::create(element::i32, Shape{1}, {value}); | ||
|
||
Output<Node> tensor_rank = std::get<1>(get_shape_rank(context, input, true)); | ||
auto tensor_rank_correct_type = context.mark_node(std::make_shared<v1::ConvertLike>(tensor_rank, dim)); | ||
auto positive_dim = normalize_axis(context, dim, tensor_rank_correct_type); | ||
|
||
// begin the computation | ||
//indx_cpy(dim,idx_def,tensor_def) | ||
//indx_fill(dim,idx_def,val) | ||
auto tensor_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32)); | ||
auto dim_vec = context.mark_node(std::make_shared<v1::Reshape>(positive_dim, const_1_vec, false)); | ||
auto broadcasted_index = context.mark_node(std::make_shared<v1::Broadcast>(index, tensor_shape, dim_vec)); | ||
//reshaping steps | ||
|
||
//so we need to create a tensor with the same shape as the input tensor | ||
//(val tensor is broadcasted to the shape of the input tensor) then index_copy | ||
//val tensor->tensor of shape of input tensor with all values v | ||
|
||
|
||
|
||
|
||
auto result = | ||
context.mark_node(std::make_shared<v12::ScatterElementsUpdate>(input, broadcasted_index, index, dim)); | ||
return {result}; | ||
}; | ||
|
||
} // namespace op | ||
} // namespace pytorch | ||
} // namespace frontend | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
# Copyright (C) 2018-2023 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
||
import numpy as np | ||
import pytest | ||
import torch | ||
|
||
from pytorch_layer_test_class import PytorchLayerTest | ||
|
||
|
||
class TestIndexFill(PytorchLayerTest): | ||
def _prepare_input(self): | ||
return (self.input_tensor, self.values) | ||
|
||
def create_model(self, dim, index): | ||
class aten_index_fill_(torch.nn.Module): | ||
def __init__(self, dim, index): | ||
super().__init__() | ||
self.dim = dim | ||
self.index = index | ||
|
||
def forward(self, input_tensor, values): | ||
input_tensor.index_fill_(self.dim, self.index, values) | ||
return input_tensor | ||
|
||
ref_net = None | ||
|
||
return aten_index_fill_(dim, index), ref_net, "aten::index_fill_" | ||
|
||
@pytest.mark.parametrize( | ||
"input_data", | ||
( | ||
{ | ||
"input_shape": [1], | ||
"dim": 0, | ||
"values_shape": [1], | ||
"index": torch.tensor([0], dtype=torch.long) | ||
}, | ||
{ | ||
"input_shape": [10], | ||
"dim": 0, | ||
"values_shape": [5], | ||
"index": torch.tensor([2, 3, 6, 7, 1], dtype=torch.long) | ||
}, | ||
{ | ||
"input_shape": [3, 3], | ||
"dim": 0, | ||
"values_shape": [2, 3], | ||
"index": torch.tensor([2, 0], dtype=torch.long) | ||
}, | ||
{ | ||
"input_shape": [4, 3, 5], | ||
"dim": 1, | ||
"values_shape": [4, 2, 5], | ||
"index": torch.tensor([1, 0], dtype=torch.long) | ||
}, | ||
{ | ||
"input_shape": [5, 6, 7, 8], | ||
"dim": -2, | ||
"values_shape": [5, 6, 4, 8], | ||
"index": torch.tensor([5, 0, 6, 3], dtype=torch.long) | ||
}, | ||
{ | ||
"input_shape": [5, 6, 7, 8], | ||
"dim": -3, | ||
"values_shape": [5, 3, 7, 8], | ||
"index": torch.tensor([2, 0, 1], dtype=torch.long) | ||
}, | ||
{ | ||
"input_shape": [5, 6, 7, 8], | ||
"dim": 3, | ||
"values_shape": [5, 6, 7, 5], | ||
"index": torch.tensor([2, 6, 0, 4, 1], dtype=torch.long) | ||
}, | ||
), | ||
) | ||
@pytest.mark.nightly | ||
@pytest.mark.precommit | ||
def test_index_copy_single_index(self, ie_device, precision, ir_version, input_data): | ||
self.input_tensor = np.random.randn(*input_data["input_shape"]).astype(np.float32) | ||
self.values = np.random.randn(*input_data["values_shape"]).astype(np.float32) | ||
index = input_data["index"] | ||
dim = input_data["dim"] | ||
self._test(*self.create_model(dim, index), ie_device, precision, ir_version) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
value
is Node, you are trying to create a constant from it.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok will fix