Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nGraph implementation of NMS-5 (without evaluate()) #2651

Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
066daac
Commit.
vgavrilo Jul 30, 2020
c212c50
Merge remote-tracking branch 'upstream/master'
vgavrilo Jul 30, 2020
228af66
Merge remote-tracking branch 'upstream/master'
vgavrilo Aug 15, 2020
66b99ac
Merge remote-tracking branch 'upstream/master'
vgavrilo Aug 20, 2020
62b2452
Merge remote-tracking branch 'upstream/master'
vgavrilo Aug 26, 2020
dd5a343
Merge remote-tracking branch 'upstream/master'
vgavrilo Sep 1, 2020
146cfcb
Merge remote-tracking branch 'upstream/master'
vgavrilo Sep 2, 2020
11cfd32
Merge remote-tracking branch 'upstream/master'
vgavrilo Sep 2, 2020
f135fdc
Merge remote-tracking branch 'upstream/master'
vgavrilo Sep 4, 2020
14b3b49
Merge remote-tracking branch 'upstream/master'
vgavrilo Sep 11, 2020
29d798a
Merge remote-tracking branch 'upstream/master'
vgavrilo Sep 16, 2020
5aa69a3
Merge remote-tracking branch 'upstream/master'
vgavrilo Sep 16, 2020
3754f40
Merge remote-tracking branch 'upstream/master'
vgavrilo Sep 17, 2020
a211ce8
Merge remote-tracking branch 'upstream/master'
vgavrilo Sep 18, 2020
e7ae609
Merge remote-tracking branch 'upstream/master'
vgavrilo Sep 21, 2020
2ed2d5c
Merge remote-tracking branch 'upstream/master'
vgavrilo Sep 21, 2020
bdbfb81
Merge remote-tracking branch 'upstream/master'
vgavrilo Sep 25, 2020
29cfcfc
Merge remote-tracking branch 'upstream/master'
vgavrilo Sep 28, 2020
e64b285
Merge remote-tracking branch 'upstream/master'
vgavrilo Sep 29, 2020
ebf97c4
Merge remote-tracking branch 'upstream/master'
vgavrilo Oct 6, 2020
42de39d
Merge remote-tracking branch 'upstream/master'
vgavrilo Oct 13, 2020
859fc8e
Written nGraph NMS-5 without evaluate().
vgavrilo Oct 13, 2020
e52c21d
Merge remote-tracking branch 'upstream/master' into vgavrilo/ngraph-n…
vgavrilo Oct 13, 2020
94695ea
Used NGRAPH_RTTI_DECLARATION.
vgavrilo Oct 13, 2020
2579819
Merge remote-tracking branch 'upstream/master' into vgavrilo/ngraph-n…
vgavrilo Oct 13, 2020
8fd1cc0
Deleted include directive for reference implementation header.
vgavrilo Oct 13, 2020
70c2f76
Used NGRAPH_RTTI_DECLARATION for inner NMS-3.
vgavrilo Oct 13, 2020
5f51b6b
Small fix.
vgavrilo Oct 13, 2020
b8dadda
Merge remote-tracking branch 'upstream/master' into vgavrilo/ngraph-n…
vgavrilo Oct 14, 2020
a7c02a0
Deleted transformation NMS-5 -> inner NMSIE-3 and NMS-1, NMS-3, NMS-4…
vgavrilo Oct 14, 2020
348e654
Deleted transformations NMS-1, NMS-3, NMS-4 -> NMS-5 from common_opti…
vgavrilo Oct 14, 2020
2ea1592
Deleted transfomation NMS-5 -> NMSIE3 from convert_opset1_to_legacy.cpp.
vgavrilo Oct 14, 2020
41450ee
Merge remote-tracking branch 'upstream/master' into vgavrilo/ngraph-n…
vgavrilo Oct 14, 2020
fb73992
Merge remote-tracking branch 'upstream/master' into vgavrilo/ngraph-n…
vgavrilo Oct 14, 2020
eaf1912
Small fix.
vgavrilo Oct 14, 2020
fc3d17e
Merge remote-tracking branch 'upstream/master' into vgavrilo/ngraph-n…
vgavrilo Oct 14, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions inference-engine/src/transformations/include/ngraph_ops/nms_ie.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,5 +58,33 @@ class TRANSFORMATIONS_API NonMaxSuppressionIE2 : public NonMaxSuppressionIE {
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector & new_args) const override;
};

class TRANSFORMATIONS_API NonMaxSuppressionIE3 : public Op {
public:
NGRAPH_RTTI_DECLARATION;

NonMaxSuppressionIE3(const Output<Node>& boxes,
const Output<Node>& scores,
const Output<Node>& max_output_boxes_per_class,
const Output<Node>& iou_threshold,
const Output<Node>& score_threshold,
const Output<Node>& soft_nms_sigma,
int center_point_box,
bool sort_result_descending,
const ngraph::element::Type& output_type = ngraph::element::i64);

void validate_and_infer_types() override;

bool visit_attributes(AttributeVisitor& visitor) override;

std::shared_ptr<Node> clone_with_new_inputs(const OutputVector & new_args) const override;

int m_center_point_box;
bool m_sort_result_descending = true;
element::Type m_output_type;

private:
int64_t max_boxes_output_from_input() const;
};

} // namespace op
} // namespace ngraph
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <vector>
#include <memory>

#include <transformations_visibility.hpp>

#include <ngraph/pass/graph_rewrite.hpp>

namespace ngraph {
namespace pass {

class TRANSFORMATIONS_API ConvertNMS5ToLegacyMatcher;

} // namespace pass
} // namespace ngraph

/*
* Description:
* Convert NMS-5 directly to inner NMS.
*/


class ngraph::pass::ConvertNMS5ToLegacyMatcher: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
ConvertNMS5ToLegacyMatcher();
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <memory>
#include <utility>

#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>

namespace ngraph {
namespace pass {

class TRANSFORMATIONS_API ConvertNMS1ToNMS5;
class TRANSFORMATIONS_API ConvertNMS3ToNMS5;
class TRANSFORMATIONS_API ConvertNMS4ToNMS5;

} // namespace pass
} // namespace ngraph

class ngraph::pass::ConvertNMS1ToNMS5: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
ConvertNMS1ToNMS5();
};

class ngraph::pass::ConvertNMS3ToNMS5: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
ConvertNMS3ToNMS5();
};

class ngraph::pass::ConvertNMS4ToNMS5: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
ConvertNMS4ToNMS5();
};
71 changes: 71 additions & 0 deletions inference-engine/src/transformations/src/ngraph_ops/nms_ie.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,74 @@ void op::NonMaxSuppressionIE2::validate_and_infer_types() {
m_output_type);
set_output_type(0, nms->output(0).get_element_type(), nms->output(0).get_partial_shape());
}

NGRAPH_RTTI_DEFINITION(op::NonMaxSuppressionIE3, "NonMaxSuppressionIE", 3);

op::NonMaxSuppressionIE3::NonMaxSuppressionIE3(const Output<Node>& boxes,
const Output<Node>& scores,
const Output<Node>& max_output_boxes_per_class,
const Output<Node>& iou_threshold,
const Output<Node>& score_threshold,
const Output<Node>& soft_nms_sigma,
int center_point_box,
bool sort_result_descending,
const ngraph::element::Type& output_type)
: Op({boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, soft_nms_sigma}),
m_center_point_box(center_point_box), m_sort_result_descending(sort_result_descending), m_output_type(output_type) {
constructor_validate_and_infer_types();
}

std::shared_ptr<Node> op::NonMaxSuppressionIE3::clone_with_new_inputs(const ngraph::OutputVector &new_args) const {
check_new_args_count(this, new_args);
return make_shared<NonMaxSuppressionIE3>(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3),
new_args.at(4), new_args.at(5), m_center_point_box, m_sort_result_descending,
m_output_type);
}

bool op::NonMaxSuppressionIE3::visit_attributes(AttributeVisitor& visitor) {
visitor.on_attribute("center_point_box", m_center_point_box);
visitor.on_attribute("sort_result_descending", m_sort_result_descending);
visitor.on_attribute("output_type", m_output_type);
return true;
}

static constexpr size_t boxes_port = 0;
static constexpr size_t scores_port = 1;
static constexpr size_t max_output_boxes_per_class_port = 2;

int64_t op::NonMaxSuppressionIE3::max_boxes_output_from_input() const {
int64_t max_output_boxes{0};

const auto max_output_boxes_input =
as_type_ptr<op::Constant>(input_value(2).get_node_shared_ptr());
max_output_boxes = max_output_boxes_input->cast_vector<int64_t>().at(0);

return max_output_boxes;
}

void op::NonMaxSuppressionIE3::validate_and_infer_types() {
const auto boxes_ps = get_input_partial_shape(boxes_port);
const auto scores_ps = get_input_partial_shape(scores_port);

// NonMaxSuppression produces triplets
// that have the following format: [batch_index, class_index, box_index]
PartialShape out_shape = {Dimension::dynamic(), 3};

if (boxes_ps.rank().is_static() && scores_ps.rank().is_static()) {
const auto num_boxes_boxes = boxes_ps[1];
const auto max_output_boxes_per_class_node = input_value(max_output_boxes_per_class_port).get_node_shared_ptr();
if (num_boxes_boxes.is_static() && scores_ps[0].is_static() && scores_ps[1].is_static() &&
op::is_constant(max_output_boxes_per_class_node)) {
const auto num_boxes = num_boxes_boxes.get_length();
const auto num_classes = scores_ps[1].get_length();
const auto max_output_boxes_per_class = max_boxes_output_from_input();

out_shape[0] = std::min(num_boxes, max_output_boxes_per_class) * num_classes *
scores_ps[0].get_length();
}
}

set_output_type(0, m_output_type, out_shape);
set_output_type(1, element::f32, out_shape);
set_output_type(2, m_output_type, Shape{1});
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include "transformations/remove_filtering_boxes_by_size.hpp"
#include "transformations/hswish_decomposition.hpp"
#include "transformations/hswish_fusion.hpp"
#include "transformations/convert_previous_nms_to_nms_5.hpp"

#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/constant_folding.hpp>
Expand Down Expand Up @@ -105,6 +106,10 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
manager.register_pass<ngraph::pass::GroupConvolutionBackpropDataMultiplyFusion>();
manager.register_pass<ngraph::pass::ConstantFolding>();

manager.register_pass<ngraph::pass::ConvertNMS1ToNMS5>();
manager.register_pass<ngraph::pass::ConvertNMS3ToNMS5>();
manager.register_pass<ngraph::pass::ConvertNMS4ToNMS5>();

auto fq_fusions = manager.register_pass<ngraph::pass::GraphRewrite>();
fq_fusions->add_matcher<ngraph::pass::FakeQuantizeMulFusion>();
fq_fusions->add_matcher<ngraph::pass::FakeQuantizeReshapeFusion>();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <memory>
#include <vector>

#include <ngraph/graph_util.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset5.hpp>
#include <ngraph_ops/nms_ie.hpp>
#include <ngraph/rt_info.hpp>
#include <transformations/utils/utils.hpp>

#include "transformations/convert_opset1_to_legacy/convert_nms_5_to_legacy.hpp"

NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertNMS5ToLegacyMatcher, "ConvertNMS5ToLegacyMatcher", 0);

ngraph::pass::ConvertNMS5ToLegacyMatcher::ConvertNMS5ToLegacyMatcher() {
auto boxes = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1000, 4});
auto scores = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1000});
auto max_output_boxes_per_class = ngraph::opset5::Constant::create(element::i64, Shape{}, {10});
auto iou_threshold = ngraph::opset5::Constant::create(element::f32, Shape{}, {0.75});
auto score_threshold = ngraph::opset5::Constant::create(element::f32, Shape{}, {0.7});
auto soft_nms_sigma = ngraph::opset5::Constant::create(element::f32, Shape{}, {0.25});
auto nms = std::make_shared<ngraph::opset5::NonMaxSuppression>(boxes, scores, max_output_boxes_per_class,
iou_threshold, score_threshold, soft_nms_sigma);

ngraph::matcher_pass_callback callback = [](pattern::Matcher &m) {
auto nms_5 = std::dynamic_pointer_cast<ngraph::opset5::NonMaxSuppression>(m.get_match_root());
if (!nms_5) {
return false;
}

const auto new_args = nms_5->input_values();
const auto& arg2 = new_args.size() > 2 ? new_args.at(2) : ngraph::opset5::Constant::create(element::i32, Shape{}, {0});
const auto& arg3 = new_args.size() > 3 ? new_args.at(3) : ngraph::opset5::Constant::create(element::f32, Shape{}, {.0f});
const auto& arg4 = new_args.size() > 4 ? new_args.at(4) : ngraph::opset5::Constant::create(element::f32, Shape{}, {.0f});
const auto& arg5 = new_args.size() > 5 ? new_args.at(5) : ngraph::opset5::Constant::create(element::f32, Shape{}, {.0f});

const auto max_output_boxes_per_class_rank = arg2.get_partial_shape().rank();
const auto iou_threshold_rank = arg3.get_partial_shape().rank();
const auto score_threshold_rank = arg4.get_partial_shape().rank();
const auto soft_nms_sigma_rank = arg5.get_partial_shape().rank();

// Check that required ranks are not dynamic
if (max_output_boxes_per_class_rank.is_dynamic() ||
iou_threshold_rank.is_dynamic() ||
score_threshold_rank.is_dynamic() ||
soft_nms_sigma_rank.is_dynamic()) {
return false;
}

if (max_output_boxes_per_class_rank.get_length() == 1 &&
iou_threshold_rank.get_length() == 1 &&
score_threshold_rank.get_length() == 1 &&
soft_nms_sigma_rank.get_length() == 1) {
return false;
}

// vector of new nGraph operations
NodeVector new_ops;

auto new_max_per_class = arg2;
if (max_output_boxes_per_class_rank.get_length() == 0) {
// WA: we need to create Constant manually because it requires by NMS shape inference
// otherwise we will get dynamic shape until first CF is executed. It can be resolved
// if CF will be executed right after transformation and before Validate pass.
if (auto new_max_per_class_const = std::dynamic_pointer_cast<opset1::Constant>(new_max_per_class.get_node_shared_ptr())) {
new_max_per_class = opset1::Constant::create(element::i64, Shape{1}, new_max_per_class_const->cast_vector<int64_t>());
} else {
new_max_per_class = std::make_shared<ngraph::op::Unsqueeze>(arg2, opset1::Constant::create(element::i64, Shape{1}, {0}));
new_ops.push_back(new_max_per_class.get_node_shared_ptr());
}
}
auto new_iou_threshold = arg3;
if (iou_threshold_rank.get_length() == 0) {
new_iou_threshold = std::make_shared<ngraph::op::Unsqueeze>(arg3, opset1::Constant::create(element::f32, Shape{1}, {0.0f}));
new_ops.push_back(new_iou_threshold.get_node_shared_ptr());
}
auto new_score_threshold = arg4;
if (score_threshold_rank.get_length() == 0) {
new_score_threshold = std::make_shared<ngraph::op::Unsqueeze>(arg4, opset1::Constant::create(element::f32, Shape{1}, {0.0f}));
new_ops.push_back(new_score_threshold.get_node_shared_ptr());
}
auto new_soft_nms_sigma = arg5;
if (soft_nms_sigma_rank.get_length() == 0) {
new_soft_nms_sigma = std::make_shared<ngraph::op::Unsqueeze>(arg5, opset1::Constant::create(element::f32, Shape{1}, {0.0f}));
new_ops.push_back(new_soft_nms_sigma.get_node_shared_ptr());
}
int center_point_box = 0;
switch (nms_5->get_box_encoding()) {
case ::ngraph::opset5::NonMaxSuppression::BoxEncodingType::CENTER:
center_point_box = 1;
break;
case ::ngraph::opset5::NonMaxSuppression::BoxEncodingType::CORNER:
center_point_box = 0;
break;
default:
throw ngraph_error("NonMaxSuppression layer " + nms_5->get_friendly_name() +
" has unsupported box encoding");
}
const auto nms_legacy = std::make_shared<op::NonMaxSuppressionIE3>(
new_args.at(0),
new_args.at(1),
new_max_per_class,
new_iou_threshold,
new_score_threshold,
new_soft_nms_sigma,
center_point_box,
nms_5->get_sort_result_descending(),
nms_5->get_output_type());
new_ops.push_back(nms_legacy);

nms_legacy->set_friendly_name(nms_5->get_friendly_name());
ngraph::copy_runtime_info(nms_5, new_ops);
ngraph::replace_node(nms_5, nms_legacy);
return true;
};

auto m = std::make_shared<ngraph::pattern::Matcher>(nms, "ConvertNMS5ToNMSLegacy");
this->register_matcher(m, callback);
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <transformations/convert_negative.hpp>
#include <transformations/convert_opset1_to_legacy/convert_nms_to_nms_ie.hpp>
#include <transformations/convert_opset1_to_legacy/convert_nms_4_to_legacy.hpp>
#include <transformations/convert_opset1_to_legacy/convert_nms_5_to_legacy.hpp>
#include <transformations/convert_opset1_to_legacy/convert_normalizel2_to_normalize_ie.hpp>
#include <transformations/convert_opset1_to_legacy/convert_one_hot_to_one_hot_ie.hpp>
#include <transformations/convert_opset1_to_legacy/convert_pad_to_pad_ie.hpp>
Expand Down Expand Up @@ -140,6 +141,7 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
anchor->add_matcher<ngraph::pass::ConvertTopKToTopKIEMatcher>();
anchor->add_matcher<ngraph::pass::ConvertNMSToNMSIEMatcher>();
anchor->add_matcher<ngraph::pass::ConvertNMS4ToLegacyMatcher>();
anchor->add_matcher<ngraph::pass::ConvertNMS5ToLegacyMatcher>();
anchor->add_matcher<ngraph::pass::ConvertGRUSequenceMatcher>();
anchor->add_matcher<ngraph::pass::ConvertRNNSequenceMatcher>();
anchor->add_matcher<ngraph::pass::ConvertLSTMSequenceMatcher>();
Expand Down
Loading