Skip to content

Commit

Permalink
[PT FE] Add aten::linspace (#18998)
Browse files Browse the repository at this point in the history
* Add linspace

* Add linspace tests

* Cleanup

* Format schema
  • Loading branch information
mmikolajcz authored Aug 7, 2023
1 parent 4ad072e commit a22c5f1
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 0 deletions.
79 changes: 79 additions & 0 deletions src/frontends/pytorch/src/op/linspace.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/equal.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/subtract.hpp"
#include "pt_framework_node.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_linspace(const NodeContext& context) {
num_inputs_check(context, 3, 7);
// "aten::linspace(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device?
// device=None, bool? pin_memory=None) -> Tensor"

// "aten::linspace.out(Scalar start, Scalar end, int steps, *, Tensor(a!) out) -> Tensor(a!)"
auto start = context.mark_node(std::make_shared<v0::Convert>(context.get_input(0), element::f32));
auto end = context.mark_node(std::make_shared<v0::Convert>(context.get_input(1), element::f32));
auto steps = context.mark_node(std::make_shared<v0::Convert>(context.get_input(2), element::f32));
auto out_tensor = context.get_input(1);
auto apply_dtype = true;
auto dtype = element::f32;
if (!context.input_is_none(3) && context.get_input_size() == 7) {
// Case where dtype is provided directly in dtype input.
if (std::dynamic_pointer_cast<v0::Constant>(context.get_input_from_visible_context(3).get_node_shared_ptr())) {
dtype = convert_dtype(context.const_input<int64_t>(3));
apply_dtype = true;
} else if (const auto& fw_node = cast_fw_node(context.get_input(3).get_node_shared_ptr(), "prim::dtype")) {
out_tensor = fw_node->input_value(0);
apply_dtype = false;
} else {
FRONT_END_OP_CONVERSION_CHECK(false, "Couldn't get dtype input");
}
} else if (!context.input_is_none(3) && context.get_input_size() == 4) {
// Case where dtype is inherited from out tensor.
out_tensor = context.get_input(3);
apply_dtype = false;
}

auto const_0 = v0::Constant::create(element::f32, Shape{}, {0});
auto const_1 = v0::Constant::create(element::f32, Shape{}, {1});
auto step_range = context.mark_node(std::make_shared<v4::Range>(const_0, steps, const_1, element::f32));

auto sub_end_start = context.mark_node(std::make_shared<v1::Subtract>(end, start));
auto sub_steps_1 = context.mark_node(std::make_shared<v1::Subtract>(steps, const_1));
auto step_multiplier = context.mark_node(std::make_shared<v1::Divide>(sub_end_start, sub_steps_1));
auto is_single_step = context.mark_node(std::make_shared<v1::Equal>(steps, const_1));
auto select_multiplier = context.mark_node(std::make_shared<v1::Select>(is_single_step, const_0, step_multiplier));
auto step_values = context.mark_node(std::make_shared<v1::Multiply>(step_range, select_multiplier));

auto linspace = context.mark_node(std::make_shared<v1::Add>(step_values, start));
if (apply_dtype) {
linspace = context.mark_node(std::make_shared<v0::Convert>(linspace, dtype));
} else {
linspace = context.mark_node(std::make_shared<v1::ConvertLike>(linspace, out_tensor));
}

return {linspace};
};

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ OP_CONVERTER(translate_linalg_norm);
OP_CONVERTER(translate_linalg_matrix_norm);
OP_CONVERTER(translate_linalg_vector_norm);
OP_CONVERTER(translate_linear);
OP_CONVERTER(translate_linspace);
OP_CONVERTER(translate_list_construct);
OP_CONVERTER(translate_list_unpack);
OP_CONVERTER(translate_log);
Expand Down Expand Up @@ -331,6 +332,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::linalg_matrix_norm", op::translate_linalg_matrix_norm},
{"aten::linalg_vector_norm", op::translate_linalg_vector_norm},
{"aten::linear", op::translate_linear},
{"aten::linspace", op::translate_linspace},
{"aten::log", op::translate_log},
{"aten::log_", op::inplace_op<op::translate_log>},
{"aten::log_softmax", op::translate_log_softmax},
Expand Down
89 changes: 89 additions & 0 deletions tests/layer_tests/pytorch_tests/test_linspace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch
from pytorch_layer_test_class import PytorchLayerTest


class TestLinspace(PytorchLayerTest):
def _prepare_input(self, start, end, steps, dtype=None, ref_dtype=None):
inputs = [np.array(start).astype(dtype), np.array(end).astype(dtype), np.array(steps).astype("int32")]
if ref_dtype:
inputs.append(np.zeros(1).astype(ref_dtype))
return inputs

def create_model(self, dtype=None, use_out=False, ref_dtype=False):
dtype_map = {
"float32": torch.float32,
"float64": torch.float64,
"int64": torch.int64,
"int32": torch.int32,
"uint8": torch.uint8,
"int8": torch.int8,
}

class aten_linspace_dtype(torch.nn.Module):
def __init__(self, dtype) -> None:
super().__init__()
self.dtype = dtype

def forward(self, start, end, steps):
return torch.linspace(start=start, end=end, steps=steps, dtype=self.dtype)

class aten_linspace_out(torch.nn.Module):
def __init__(self, out) -> None:
super().__init__()
# Size of empty tensor needs to be of equal or larger size than linspace steps
self.out = torch.empty(25, dtype=out)

def forward(self, start, end, steps):
return torch.linspace(start=start, end=end, steps=steps, out=self.out)

class aten_linspace_prim_dtype(torch.nn.Module):
def forward(self, start, end, steps, d):
return torch.linspace(start=start, end=end, steps=steps, dtype=d.dtype)

dtype = dtype_map.get(dtype)
if ref_dtype:
model_class = aten_linspace_prim_dtype()
elif not use_out:
model_class = aten_linspace_dtype(dtype)
else:
model_class = aten_linspace_out(dtype)

ref_net = None

return model_class, ref_net, "aten::linspace"

@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.parametrize("dtype", ["float32", "float64", "int32", "int64", "int8"])
@pytest.mark.parametrize(
"start,end,steps", [(0, 1, 5), (-2, 1, 5), (1, -5, 7), (1, 10, 2), (-1, -5, 2), (-1, -5, 1), (1.25, -5.5, 5)]
)
def test_linspace_with_prim_dtype(self, dtype, end, start, steps, ie_device, precision, ir_version):
self._test(
*self.create_model(dtype, ref_dtype=True),
ie_device,
precision,
ir_version,
kwargs_to_prepare_input={"end": end, "start": start, "steps": steps, "ref_dtype": dtype}
)

@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.parametrize("dtype", [None, "float32", "float64", "int32", "int64", "int8", "uin8"])
@pytest.mark.parametrize(
"start,end,steps", [(0, 1, 5), (-2, 1, 5), (1, -5, 7), (1, 10, 2), (-1, -5, 2), (-1, -5, 1), (1.25, -5.5, 5)]
)
@pytest.mark.parametrize("use_out", [False, True])
def test_linspace_with_out(self, dtype, use_out, end, start, steps, ie_device, precision, ir_version):
self._test(
*self.create_model(dtype=dtype, use_out=use_out),
ie_device,
precision,
ir_version,
kwargs_to_prepare_input={"end": end, "start": start, "steps": steps}
)

0 comments on commit a22c5f1

Please sign in to comment.