Skip to content

Commit

Permalink
Unary Visitor test template fix
Browse files Browse the repository at this point in the history
-Migrate OP Tanh to use RTTI;
-Remove the using namespace in the header file
-Migrate the Swish and Tanh visitor test to use template code


Signed-off-by: Luwei Zhou <[email protected]>
  • Loading branch information
luweizhou2016 committed Jul 1, 2021
1 parent e7b0fc2 commit b686c93
Show file tree
Hide file tree
Showing 16 changed files with 34 additions and 53 deletions.
3 changes: 1 addition & 2 deletions ngraph/core/include/ngraph/op/tanh.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ namespace ngraph
class NGRAPH_API Tanh : public util::UnaryElementwiseArithmetic
{
public:
static constexpr NodeTypeInfo type_info{"Tanh", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a hyperbolic tangent operation.
///
/// \param arg Node that produces the input tensor.
Expand Down
2 changes: 1 addition & 1 deletion ngraph/core/src/op/tanh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
using namespace std;
using namespace ngraph;

constexpr NodeTypeInfo op::Tanh::type_info;
NGRAPH_RTTI_DEFINITION(op::v0::Tanh, "Tanh", 0, util::UnaryElementwiseArithmetic);

op::Tanh::Tanh(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
Expand Down
1 change: 1 addition & 0 deletions ngraph/test/visitors/op/ceiling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//

#include "unary_ops.hpp"
using namespace ngraph;
using Type = ::testing::Types<UnaryOperatorType<ngraph::op::v0::Ceiling, element::f32>>;

INSTANTIATE_TYPED_TEST_SUITE_P(visitor_without_attribute,
Expand Down
1 change: 1 addition & 0 deletions ngraph/test/visitors/op/cos.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//

#include "unary_ops.hpp"
using namespace ngraph;
using Type = ::testing::Types<UnaryOperatorType<ngraph::op::v0::Cos, element::f32>>;

INSTANTIATE_TYPED_TEST_SUITE_P(visitor_without_attribute,
Expand Down
1 change: 1 addition & 0 deletions ngraph/test/visitors/op/erf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "unary_ops.hpp"

using namespace ngraph;
using Type = ::testing::Types<UnaryOperatorType<ngraph::op::v0::Erf, element::f32>>;

INSTANTIATE_TYPED_TEST_SUITE_P(visitor_without_atrribute,
Expand Down
1 change: 1 addition & 0 deletions ngraph/test/visitors/op/floor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "unary_ops.hpp"

using namespace ngraph;
using Types = ::testing::Types<UnaryOperatorType<ngraph::op::v0::Floor, element::f32>,
UnaryOperatorType<ngraph::op::v0::Floor, element::f16>>;

Expand Down
1 change: 1 addition & 0 deletions ngraph/test/visitors/op/log.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "unary_ops.hpp"
using namespace ngraph;
using Types = ::testing::Types<UnaryOperatorType<ngraph::op::v0::Log, element::f32>,
UnaryOperatorType<ngraph::op::v0::Log, element::f16>>;

Expand Down
26 changes: 6 additions & 20 deletions ngraph/test/visitors/op/mish.cpp
Original file line number Diff line number Diff line change
@@ -1,26 +1,12 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "unary_ops.hpp"

#include "gtest/gtest.h"

#include "ngraph/ngraph.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/opsets/opset4.hpp"
#include "util/visitor.hpp"

using namespace std;
using namespace ngraph;
using ngraph::test::NodeBuilder;

TEST(attributes, mish_op)
{
NodeBuilder::get_ops().register_factory<opset4::Mish>();
const auto A = make_shared<op::Parameter>(element::f32, Shape{5, 2});

const auto mish = make_shared<opset4::Mish>(A);
NodeBuilder builder(mish);
using Type = ::testing::Types<UnaryOperatorType<ngraph::op::v4::Mish, element::f32>>;

const auto expected_attr_count = 0;
EXPECT_EQ(builder.get_value_map_size(), expected_attr_count);
}
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_without_attribute,
UnaryOperatorVisitor,
Type,
UnaryOperatorTypeName);
1 change: 1 addition & 0 deletions ngraph/test/visitors/op/negative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "unary_ops.hpp"

using namespace ngraph;
using Types = ::testing::Types<UnaryOperatorType<ngraph::op::v0::Negative, element::f32>,
UnaryOperatorType<ngraph::op::v0::Negative, element::i32>>;

Expand Down
1 change: 1 addition & 0 deletions ngraph/test/visitors/op/result.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "unary_ops.hpp"

using namespace ngraph;
using Types = ::testing::Types<UnaryOperatorType<ngraph::op::v0::Result, element::f32>,
UnaryOperatorType<ngraph::op::v0::Result, element::f16>>;

Expand Down
1 change: 1 addition & 0 deletions ngraph/test/visitors/op/softplus.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "unary_ops.hpp"

using namespace ngraph;
using Types = ::testing::Types<UnaryOperatorType<ngraph::op::v4::SoftPlus, element::f32>>;

INSTANTIATE_TYPED_TEST_SUITE_P(visitor_without_atrribute,
Expand Down
2 changes: 2 additions & 0 deletions ngraph/test/visitors/op/sqrt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
//

#include "unary_ops.hpp"

using namespace ngraph;
using Types = ::testing::Types<UnaryOperatorType<ngraph::op::v0::Sqrt, element::f32>,
UnaryOperatorType<ngraph::op::v0::Sqrt, element::f16>>;

Expand Down
2 changes: 2 additions & 0 deletions ngraph/test/visitors/op/squeeze.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
//

#include "unary_ops.hpp"

using namespace ngraph;
using Types = ::testing::Types<UnaryOperatorType<ngraph::op::v0::Squeeze, element::f32>,
UnaryOperatorType<ngraph::op::v0::Squeeze, element::f16>>;

Expand Down
1 change: 1 addition & 0 deletions ngraph/test/visitors/op/swish.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//
#include "unary_ops.hpp"

using namespace ngraph;
using Type = ::testing::Types<UnaryOperatorType<ngraph::op::v4::Swish, element::f32>>;

INSTANTIATE_TYPED_TEST_SUITE_P(visitor_without_atrribute,
Expand Down
27 changes: 6 additions & 21 deletions ngraph/test/visitors/op/tanh.cpp
Original file line number Diff line number Diff line change
@@ -1,27 +1,12 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "unary_ops.hpp"

#include "gtest/gtest.h"

#include "ngraph/ngraph.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/opsets/opset1.hpp"

#include "util/visitor.hpp"

using namespace std;
using namespace ngraph;
using ngraph::test::NodeBuilder;

TEST(attributes, tanh_op)
{
NodeBuilder::get_ops().register_factory<op::Tanh>();
const auto data_node = make_shared<op::Parameter>(element::f32, Shape{1});
const auto tanh = make_shared<op::Tanh>(data_node);

const NodeBuilder builder(tanh);
const auto tanh_attr_number = 0;
using Type = ::testing::Types<UnaryOperatorType<ngraph::op::v0::Tanh, element::f32>>;

EXPECT_EQ(builder.get_value_map_size(), tanh_attr_number);
}
INSTANTIATE_TYPED_TEST_SUITE_P(visitor_without_attribute,
UnaryOperatorVisitor,
Type,
UnaryOperatorTypeName);
16 changes: 7 additions & 9 deletions ngraph/test/visitors/op/unary_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,12 @@
#include "ngraph/op/util/attr_types.hpp"
#include "util/visitor.hpp"

using namespace ngraph;
using ngraph::test::NodeBuilder;
template <typename T, element::Type_t ELEMENT_TYPE>
template <typename T, ngraph::element::Type_t ELEMENT_TYPE>
class UnaryOperatorType
{
public:
using op_type = T;
static constexpr element::Type_t element_type = ELEMENT_TYPE;
static constexpr ngraph::element::Type_t element_type = ELEMENT_TYPE;
};
template <typename T>
class UnaryOperatorVisitor : public testing::Test
Expand All @@ -32,7 +30,7 @@ class UnaryOperatorTypeName
static std::string GetName(int)
{
using OP_Type = typename T::op_type;
constexpr element::Type precision(T::element_type);
constexpr ngraph::element::Type precision(T::element_type);
const ngraph::Node::type_info_t typeinfo = OP_Type::get_type_info_static();
std::string op_name{typeinfo.name};
op_name.append("_");
Expand All @@ -45,13 +43,13 @@ TYPED_TEST_SUITE_P(UnaryOperatorVisitor);
TYPED_TEST_P(UnaryOperatorVisitor, No_Attribute_4D)
{
using OP_Type = typename TypeParam::op_type;
const element::Type_t element_type = TypeParam::element_type;
const ngraph::element::Type_t element_type = TypeParam::element_type;

NodeBuilder::get_ops().register_factory<OP_Type>();
const auto A = std::make_shared<op::Parameter>(element_type, PartialShape{2, 2, 2, 2});
ngraph::test::NodeBuilder::get_ops().register_factory<OP_Type>();
const auto A = std::make_shared<ngraph::op::Parameter>(element_type, ngraph::PartialShape{2, 2, 2, 2});

const auto op_func = std::make_shared<OP_Type>(A);
NodeBuilder builder(op_func);
ngraph::test::NodeBuilder builder(op_func);
const auto expected_attr_count = 0;
EXPECT_EQ(builder.get_value_map_size(), expected_attr_count);
}
Expand Down

0 comments on commit b686c93

Please sign in to comment.