Commit a22c5f1 (1 parent: 4ad072e)

* Add linspace
* Add linspace tests
* Cleanup
* Format schema

Showing 3 changed files with 170 additions and 0 deletions.
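
With this conversion in place, a PyTorch model that calls torch.linspace can be converted through the PyTorch frontend end to end. A minimal sketch of such a round trip, assuming the openvino>=2023.1 Python API (ov.convert_model / ov.compile_model) and a hypothetical example module named Ruler that is not part of this commit:

import numpy as np
import torch
import openvino as ov  # assumption: openvino>=2023.1 Python API

class Ruler(torch.nn.Module):  # hypothetical example module, not part of this commit
    def forward(self, x):
        # Five evenly spaced points in [0, 1]; traced as aten::linspace and handled by the new conversion
        return torch.linspace(0.0, 1.0, 5, dtype=torch.float32) + x

ov_model = ov.convert_model(Ruler(), example_input=torch.zeros(5))
compiled = ov.compile_model(ov_model, "CPU")
print(compiled(np.zeros(5, dtype=np.float32))[0])  # expected: [0.   0.25 0.5  0.75 1.  ]
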
@@ -0,0 +1,79 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/equal.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/subtract.hpp"
#include "pt_framework_node.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_linspace(const NodeContext& context) {
    num_inputs_check(context, 3, 7);
    // "aten::linspace(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device?
    // device=None, bool? pin_memory=None) -> Tensor"

    // "aten::linspace.out(Scalar start, Scalar end, int steps, *, Tensor(a!) out) -> Tensor(a!)"
    auto start = context.mark_node(std::make_shared<v0::Convert>(context.get_input(0), element::f32));
    auto end = context.mark_node(std::make_shared<v0::Convert>(context.get_input(1), element::f32));
    auto steps = context.mark_node(std::make_shared<v0::Convert>(context.get_input(2), element::f32));
    auto out_tensor = context.get_input(1);
    auto apply_dtype = true;
    auto dtype = element::f32;
    if (!context.input_is_none(3) && context.get_input_size() == 7) {
        // Case where dtype is provided directly in dtype input.
        if (std::dynamic_pointer_cast<v0::Constant>(context.get_input_from_visible_context(3).get_node_shared_ptr())) {
            dtype = convert_dtype(context.const_input<int64_t>(3));
            apply_dtype = true;
        } else if (const auto& fw_node = cast_fw_node(context.get_input(3).get_node_shared_ptr(), "prim::dtype")) {
            out_tensor = fw_node->input_value(0);
            apply_dtype = false;
        } else {
            FRONT_END_OP_CONVERSION_CHECK(false, "Couldn't get dtype input");
        }
    } else if (!context.input_is_none(3) && context.get_input_size() == 4) {
        // Case where dtype is inherited from out tensor.
        out_tensor = context.get_input(3);
        apply_dtype = false;
    }

    auto const_0 = v0::Constant::create(element::f32, Shape{}, {0});
    auto const_1 = v0::Constant::create(element::f32, Shape{}, {1});
    auto step_range = context.mark_node(std::make_shared<v4::Range>(const_0, steps, const_1, element::f32));

    auto sub_end_start = context.mark_node(std::make_shared<v1::Subtract>(end, start));
    auto sub_steps_1 = context.mark_node(std::make_shared<v1::Subtract>(steps, const_1));
    auto step_multiplier = context.mark_node(std::make_shared<v1::Divide>(sub_end_start, sub_steps_1));
    auto is_single_step = context.mark_node(std::make_shared<v1::Equal>(steps, const_1));
    auto select_multiplier = context.mark_node(std::make_shared<v1::Select>(is_single_step, const_0, step_multiplier));
    auto step_values = context.mark_node(std::make_shared<v1::Multiply>(step_range, select_multiplier));

    auto linspace = context.mark_node(std::make_shared<v1::Add>(step_values, start));
    if (apply_dtype) {
        linspace = context.mark_node(std::make_shared<v0::Convert>(linspace, dtype));
    } else {
        linspace = context.mark_node(std::make_shared<v1::ConvertLike>(linspace, out_tensor));
    }

    return {linspace};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
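
As a quick sanity check on the graph built above: the translation computes Range(0, steps) in f32, scales it by (end - start) / (steps - 1), and uses Select to force a zero multiplier when steps == 1 so the single output equals start. A NumPy sketch of the same arithmetic (illustration only, not part of the commit):

import numpy as np

def linspace_like_translation(start, end, steps):
    # Mirrors translate_linspace: all math in f32; the caller then converts to the target dtype.
    start, end, steps = np.float32(start), np.float32(end), np.float32(steps)
    step_range = np.arange(0, steps, 1, dtype=np.float32)                       # Range(0, steps, 1)
    multiplier = np.float32(0) if steps == 1 else (end - start) / (steps - 1)   # Select(Equal(steps, 1), 0, Divide(...))
    return step_range * multiplier + start                                      # Multiply, then Add

print(linspace_like_translation(-1, -5, 1))      # [-1.]
print(linspace_like_translation(1.25, -5.5, 5))  # matches torch.linspace(1.25, -5.5, 5)
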
@@ -0,0 +1,89 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch
from pytorch_layer_test_class import PytorchLayerTest


class TestLinspace(PytorchLayerTest):
    def _prepare_input(self, start, end, steps, dtype=None, ref_dtype=None):
        inputs = [np.array(start).astype(dtype), np.array(end).astype(dtype), np.array(steps).astype("int32")]
        if ref_dtype:
            inputs.append(np.zeros(1).astype(ref_dtype))
        return inputs

    def create_model(self, dtype=None, use_out=False, ref_dtype=False):
        dtype_map = {
            "float32": torch.float32,
            "float64": torch.float64,
            "int64": torch.int64,
            "int32": torch.int32,
            "uint8": torch.uint8,
            "int8": torch.int8,
        }

        class aten_linspace_dtype(torch.nn.Module):
            def __init__(self, dtype) -> None:
                super().__init__()
                self.dtype = dtype

            def forward(self, start, end, steps):
                return torch.linspace(start=start, end=end, steps=steps, dtype=self.dtype)

        class aten_linspace_out(torch.nn.Module):
            def __init__(self, out) -> None:
                super().__init__()
                # Size of empty tensor needs to be of equal or larger size than linspace steps
                self.out = torch.empty(25, dtype=out)

            def forward(self, start, end, steps):
                return torch.linspace(start=start, end=end, steps=steps, out=self.out)

        class aten_linspace_prim_dtype(torch.nn.Module):
            def forward(self, start, end, steps, d):
                return torch.linspace(start=start, end=end, steps=steps, dtype=d.dtype)

        dtype = dtype_map.get(dtype)
        if ref_dtype:
            model_class = aten_linspace_prim_dtype()
        elif not use_out:
            model_class = aten_linspace_dtype(dtype)
        else:
            model_class = aten_linspace_out(dtype)

        ref_net = None

        return model_class, ref_net, "aten::linspace"

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize("dtype", ["float32", "float64", "int32", "int64", "int8"])
    @pytest.mark.parametrize(
        "start,end,steps", [(0, 1, 5), (-2, 1, 5), (1, -5, 7), (1, 10, 2), (-1, -5, 2), (-1, -5, 1), (1.25, -5.5, 5)]
    )
    def test_linspace_with_prim_dtype(self, dtype, end, start, steps, ie_device, precision, ir_version):
        self._test(
            *self.create_model(dtype, ref_dtype=True),
            ie_device,
            precision,
            ir_version,
            kwargs_to_prepare_input={"end": end, "start": start, "steps": steps, "ref_dtype": dtype}
        )

    @pytest.mark.nightly
    @pytest.mark.precommit
@pytest.mark.parametrize("dtype", [None, "float32", "float64", "int32", "int64", "int8", "uin8"]) | ||
    @pytest.mark.parametrize(
        "start,end,steps", [(0, 1, 5), (-2, 1, 5), (1, -5, 7), (1, 10, 2), (-1, -5, 2), (-1, -5, 1), (1.25, -5.5, 5)]
    )
    @pytest.mark.parametrize("use_out", [False, True])
    def test_linspace_with_out(self, dtype, use_out, end, start, steps, ie_device, precision, ir_version):
        self._test(
            *self.create_model(dtype=dtype, use_out=use_out),
            ie_device,
            precision,
            ir_version,
            kwargs_to_prepare_input={"end": end, "start": start, "steps": steps}
        )
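
The out= and prim::dtype branches exercised above cover the case where the result dtype is inherited from an existing tensor rather than passed as an explicit dtype argument. In eager PyTorch that behaviour looks like the following small illustration (not part of the test suite):

import torch

out = torch.empty(5, dtype=torch.int64)
torch.linspace(0, 8, steps=5, out=out)  # the result dtype is taken from the out tensor
print(out)  # tensor([0, 2, 4, 6, 8])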