Skip to content

Commit

Permalink
[LPT] ConvertSubtract fix
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Oct 20, 2020
1 parent e1a5cee commit e87eb4f
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,31 +21,29 @@ ngraph::pass::ConvertSubtract::ConvertSubtract() {
if (!sub) {
return false;
}
if (sub->output(0).get_target_inputs().empty()) {
return false;
}

if (sub->input(0).get_element_type() != sub->input(1).get_element_type()) {
return false;
}

std::shared_ptr<Node> child = sub->output(0).get_target_inputs().begin()->get_node()->shared_from_this();
if (child->output(0).get_target_inputs().empty()) {
return false;
}
std::shared_ptr<Node> childchild = child->output(0).get_target_inputs().begin()->get_node()->shared_from_this();
if (is_type<opset1::Convolution>(child) ||
is_type<opset1::GroupConvolution>(child) ||
is_type<opset1::MatMul>(child) ||
(is_type<opset1::Reshape>(child) && is_type<opset1::GroupConvolution>(childchild))) {
const auto input1Type = sub->input(0).get_element_type();
const auto input2Type = sub->input(1).get_element_type();
if (((input1Type == element::u8) && (input2Type == element::u8)) ||
((input1Type == element::i8) && (input2Type == element::i8))) {
// we should not execute transformation by reasons:
// 1. LPT asymmetric quantization pattern has to be keep as is
// 2. Subtract operation has unsigned/signed integer value which is not safe to multiply by -1
return false;
if (sub->input(0).get_element_type() == sub->input(1).get_element_type()) {
const auto subChildren = sub->output(0).get_target_inputs();
if (subChildren.size() == 1ul) {
const std::shared_ptr<Node> child = subChildren.begin()->get_node()->shared_from_this();
if (child != nullptr) {
if (is_type<opset1::Convolution>(child) ||
is_type<opset1::GroupConvolution>(child) ||
is_type<opset1::MatMul>(child) ||
(is_type<opset1::Reshape>(child) &&
(child->output(0).get_target_inputs().size() == 1ul) &&
is_type<opset1::GroupConvolution>(child->output(0).get_target_inputs().begin()->get_node()->shared_from_this()))) {
const auto input1Type = sub->input(0).get_element_type();
const auto input2Type = sub->input(1).get_element_type();
if (((input1Type == element::u8) && (input2Type == element::u8)) ||
((input1Type == element::i8) && (input2Type == element::i8))) {
// we should not execute transformation by reasons:
// 1. LPT asymmetric quantization pattern has to be keep as is
// 2. Subtract operation has unsigned/signed integer value which is not safe to multiply by -1
return false;
}
}
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,6 @@ TEST(TransformationTests, LogSoftmaxDecomposition) {
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{sub_end}, ngraph::ParameterVector{input0});
}

auto res = compare_functions(f, f_ref);
auto res = compare_functions(f, f_ref, false, false, false, false);
ASSERT_TRUE(res.first) << res.second;
}

0 comments on commit e87eb4f

Please sign in to comment.