Skip to content

Commit

Permalink
Variadic Split (#91)
Browse files Browse the repository at this point in the history
* variadic split w/o -1

* -1 implemented

* Fix -1 support in variadic split and test it

* Style consistency

* Cleanup

* Fix exception message
  • Loading branch information
mwyi authored Nov 2, 2020
1 parent 3aa9872 commit 7048f75
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 8 deletions.
72 changes: 72 additions & 0 deletions inference-engine/src/plaidml_plugin/ops/variadic_split.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "plaidml_ops.hpp"
#include "plaidml_util.hpp"

#include "ngraph/opsets/opset.hpp"
#include "ngraph/opsets/opset1.hpp"

#include "plaidml/op/op.h"

using namespace plaidml; // NOLINT[build/namespaces]
using namespace InferenceEngine; // NOLINT[build/namespaces]

namespace {

template <typename T>
std::vector<T> cast_constant_operand(size_t operand_idx, ngraph::Node* layer) {
auto* ngraph_const = ngraph::as_type<ngraph::op::Constant>(layer->get_input_node_ptr(operand_idx));
if (ngraph_const) {
return ngraph_const->cast_vector<T>();
} else {
THROW_IE_EXCEPTION << "Dynamic split lengths not currently supported by PlaidML plugin; all of split_lengths "
"must be Constants";
}
}

} // namespace

namespace PlaidMLPlugin {

static OpRegistration reg("variadicsplit", [](const Context& ctx) {
IE_ASSERT(ctx.operands.size() == 3);
auto I = ctx.operands.at(0);
auto axes = get_axis_vector_from_constant_operand(1, ctx.layer);
IE_ASSERT(axes.size() == 1);
auto axis = axes[0];
auto split_lengths = cast_constant_operand<int32_t>(2, ctx.layer);

auto ndims = I.rank();
std::vector<edsl::TensorDim> I_dims(ndims);
std::vector<edsl::TensorIndex> I_idxs(ndims);
std::vector<edsl::Tensor> Os;
I.bind_dims(I_dims);
auto O_dims = I_dims;

size_t split_size = 0;
for (auto split : split_lengths) {
if (split != -1) {
split_size += split;
}
}
auto placeholder = I_dims[axis] - split_size;

edsl::TensorDim offset(0);
for (auto split : split_lengths) {
auto O_idxs = I_idxs;
O_idxs[axis] = I_idxs[axis] - offset;
if (split == -1) {
O_dims[axis] = placeholder;
offset = offset + placeholder;
} else {
O_dims[axis] = edsl::TensorDim(split);
offset = offset + split;
}
Os.push_back(edsl::Contraction().outShape(O_dims).outAccess(O_idxs).assign(I(I_idxs)));
}
return edsl::make_tuple(Os);
});

} // namespace PlaidMLPlugin
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ namespace {
};

// Sum of elements numSplits = inputShapes[Axis]
const std::vector<std::vector<size_t>> numSplits = {
const std::vector<std::vector<int32_t>> numSplits = {
{1, 16, 5, 8},
{2, 19, 5, 4},
{7, 13, 2, 8},
{5, 8, 12, 5},
{4, 11, 6, 9}
{4, 11, -1, 9}
};

INSTANTIATE_TEST_CASE_P(NumSplitsCheck, VariadicSplitLayerTest,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (C) 2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <vector>

#include "single_layer_tests/variadic_split.hpp"
#include "common_test_utils/test_constants.hpp"

using namespace LayerTestsDefinitions;

namespace {

const std::vector<InferenceEngine::Precision> netPrecisions = {
InferenceEngine::Precision::FP32,
// InferenceEngine::Precision::FP16
};

// Sum of elements numSplits = inputShapes[Axis]
const std::vector<std::vector<int32_t>> numSplits = {
{1, 16, 5, 8},
{2, 19, 5, 4},
{7, 13, 2, 8},
{5, 8, 12, 5},
{4, 11, -1, 9}
};

INSTANTIATE_TEST_CASE_P(NumSplitsCheck, VariadicSplitLayerTest,
::testing::Combine(
::testing::ValuesIn(numSplits),
::testing::Values(0, 1, 2, 3),
::testing::ValuesIn(netPrecisions),
::testing::Values(std::vector<size_t>({30, 30, 30, 30})),
::testing::Values(CommonTestUtils::DEVICE_PLAIDML)),
VariadicSplitLayerTest::getTestCaseName);
} // namespace
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
namespace LayerTestsDefinitions {

typedef std::tuple<
std::vector<size_t>, // Num splits
std::vector<int32_t>, // Num splits
size_t, // Axis
InferenceEngine::Precision, // Net precision
std::vector<size_t>, // Input shapes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace LayerTestsDefinitions {

std::string VariadicSplitLayerTest::getTestCaseName(testing::TestParamInfo<VariadicSplitParams> obj) {
size_t axis;
std::vector<size_t> numSplits;
std::vector<int32_t> numSplits;
InferenceEngine::Precision netPrecision;
InferenceEngine::SizeVector inputShapes;
std::string targetDevice;
Expand All @@ -41,7 +41,8 @@ namespace LayerTestsDefinitions {
void VariadicSplitLayerTest::SetUp() {
SetRefMode(LayerTestsUtils::RefMode::CONSTANT_FOLDING);
size_t axis;
std::vector<size_t> inputShape, numSplits;
std::vector<size_t> inputShape;
std::vector<int32_t> numSplits;
InferenceEngine::Precision netPrecision;
std::tie(numSplits, axis, netPrecision, inputShape, targetDevice) = this->GetParam();
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ std::shared_ptr<ngraph::Node> makeSplit(const ngraph::Output<Node> &in,
size_t axis);

std::shared_ptr<ngraph::Node> makeVariadicSplit(const ngraph::Output<Node> &in,
const std::vector<size_t> numSplits,
const std::vector<int32_t> numSplits,
size_t axis);

std::shared_ptr<ngraph::Node> makeActivation(const ngraph::Output<Node> &in,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
namespace ngraph {
namespace builder {
std::shared_ptr<ngraph::Node> makeVariadicSplit(const ngraph::Output<Node> &in,
const std::vector<size_t> numSplits,
const std::vector<int32_t> numSplits,
size_t axis) {
auto splitAxisOp = std::make_shared<ngraph::opset3::Constant>(element::u64, ngraph::Shape{},
std::vector<size_t>{axis});
auto numSplit = std::make_shared<ngraph::opset3::Constant>(element::u64, ngraph::Shape{numSplits.size()},
auto numSplit = std::make_shared<ngraph::opset3::Constant>(element::i64, ngraph::Shape{numSplits.size()},
numSplits);
auto VariadicSplitNode = std::make_shared<ngraph::opset3::VariadicSplit>(in, splitAxisOp, numSplit);
return VariadicSplitNode;
Expand Down

0 comments on commit 7048f75

Please sign in to comment.