From 7bf3e607cdea86fefa88532e34b29c71ed817544 Mon Sep 17 00:00:00 2001 From: Maxim Vafin Date: Wed, 6 Sep 2023 11:10:37 +0200 Subject: [PATCH] Support aten::tile op --- src/frontends/pytorch/src/op_table.cpp | 1 + tests/layer_tests/pytorch_tests/test_tile.py | 33 ++++++++++++++++++++ 2 files changed, 34 insertions(+) create mode 100644 tests/layer_tests/pytorch_tests/test_tile.py diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 244b8fd804a1e1..67bee51301868e 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -435,6 +435,7 @@ const std::map get_supported_ops_ts() { {"aten::tanh", op::translate_1to1_match_1_inputs_with_fp32_type_alignment}, {"aten::tanh_", op::inplace_op>}, {"aten::tensor", op::translate_as_tensor}, + {"aten::tile", op::translate_1to1_match_2_inputs}, {"aten::to", op::translate_to}, {"aten::topk", op::translate_topk}, {"aten::transpose", op::quantizable_op}, diff --git a/tests/layer_tests/pytorch_tests/test_tile.py b/tests/layer_tests/pytorch_tests/test_tile.py new file mode 100644 index 00000000000000..d0223b95a7147c --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_tile.py @@ -0,0 +1,33 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from pytorch_layer_test_class import PytorchLayerTest + + +class TestTile(PytorchLayerTest): + def _prepare_input(self): + import numpy as np + return (np.random.randn(1, 3, 224, 224).astype(np.float32),) + + def create_model(self, dims): + import torch + + class aten_tile(torch.nn.Module): + def __init__(self, dims): + super(aten_tile, self).__init__() + self.dims = dims + + def forward(self, x): + return torch.tile(x, self.dims) + + ref_net = None + + return aten_tile(dims), ref_net, "aten::tile" + + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.parametrize("dims", [(2, 2), (1, 1), (1, 2, 3, 4)]) + def test_tile(self, dims, ie_device, precision, ir_version): + self._test(*self.create_model(dims), ie_device, precision, ir_version)