forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* variadic split w/o -1 * -1 implemented * Fix -1 support in variadic split and test it * Style consistency * Cleanup * Fix exception message
- Loading branch information
Showing
7 changed files
with
117 additions
and
8 deletions.
There are no files selected for viewing
72 changes: 72 additions & 0 deletions
72
inference-engine/src/plaidml_plugin/ops/variadic_split.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
// Copyright (C) 2020 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "plaidml_ops.hpp" | ||
#include "plaidml_util.hpp" | ||
|
||
#include "ngraph/opsets/opset.hpp" | ||
#include "ngraph/opsets/opset1.hpp" | ||
|
||
#include "plaidml/op/op.h" | ||
|
||
using namespace plaidml; // NOLINT[build/namespaces] | ||
using namespace InferenceEngine; // NOLINT[build/namespaces] | ||
|
||
namespace { | ||
|
||
template <typename T> | ||
std::vector<T> cast_constant_operand(size_t operand_idx, ngraph::Node* layer) { | ||
auto* ngraph_const = ngraph::as_type<ngraph::op::Constant>(layer->get_input_node_ptr(operand_idx)); | ||
if (ngraph_const) { | ||
return ngraph_const->cast_vector<T>(); | ||
} else { | ||
THROW_IE_EXCEPTION << "Dynamic split lengths not currently supported by PlaidML plugin; all of split_lengths " | ||
"must be Constants"; | ||
} | ||
} | ||
|
||
} // namespace | ||
|
||
namespace PlaidMLPlugin { | ||
|
||
static OpRegistration reg("variadicsplit", [](const Context& ctx) { | ||
IE_ASSERT(ctx.operands.size() == 3); | ||
auto I = ctx.operands.at(0); | ||
auto axes = get_axis_vector_from_constant_operand(1, ctx.layer); | ||
IE_ASSERT(axes.size() == 1); | ||
auto axis = axes[0]; | ||
auto split_lengths = cast_constant_operand<int32_t>(2, ctx.layer); | ||
|
||
auto ndims = I.rank(); | ||
std::vector<edsl::TensorDim> I_dims(ndims); | ||
std::vector<edsl::TensorIndex> I_idxs(ndims); | ||
std::vector<edsl::Tensor> Os; | ||
I.bind_dims(I_dims); | ||
auto O_dims = I_dims; | ||
|
||
size_t split_size = 0; | ||
for (auto split : split_lengths) { | ||
if (split != -1) { | ||
split_size += split; | ||
} | ||
} | ||
auto placeholder = I_dims[axis] - split_size; | ||
|
||
edsl::TensorDim offset(0); | ||
for (auto split : split_lengths) { | ||
auto O_idxs = I_idxs; | ||
O_idxs[axis] = I_idxs[axis] - offset; | ||
if (split == -1) { | ||
O_dims[axis] = placeholder; | ||
offset = offset + placeholder; | ||
} else { | ||
O_dims[axis] = edsl::TensorDim(split); | ||
offset = offset + split; | ||
} | ||
Os.push_back(edsl::Contraction().outShape(O_dims).outAccess(O_idxs).assign(I(I_idxs))); | ||
} | ||
return edsl::make_tuple(Os); | ||
}); | ||
|
||
} // namespace PlaidMLPlugin |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
36 changes: 36 additions & 0 deletions
36
...ts/functional/plugin/plaidml/shared_tests_instances/single_layer_tests/variadic_split.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
// Copyright (C) 2019 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include <vector> | ||
|
||
#include "single_layer_tests/variadic_split.hpp" | ||
#include "common_test_utils/test_constants.hpp" | ||
|
||
using namespace LayerTestsDefinitions; | ||
|
||
namespace { | ||
|
||
const std::vector<InferenceEngine::Precision> netPrecisions = { | ||
InferenceEngine::Precision::FP32, | ||
// InferenceEngine::Precision::FP16 | ||
}; | ||
|
||
// Sum of elements numSplits = inputShapes[Axis] | ||
const std::vector<std::vector<int32_t>> numSplits = { | ||
{1, 16, 5, 8}, | ||
{2, 19, 5, 4}, | ||
{7, 13, 2, 8}, | ||
{5, 8, 12, 5}, | ||
{4, 11, -1, 9} | ||
}; | ||
|
||
INSTANTIATE_TEST_CASE_P(NumSplitsCheck, VariadicSplitLayerTest, | ||
::testing::Combine( | ||
::testing::ValuesIn(numSplits), | ||
::testing::Values(0, 1, 2, 3), | ||
::testing::ValuesIn(netPrecisions), | ||
::testing::Values(std::vector<size_t>({30, 30, 30, 30})), | ||
::testing::Values(CommonTestUtils::DEVICE_PLAIDML)), | ||
VariadicSplitLayerTest::getTestCaseName); | ||
} // namespace |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters