Skip to content

Commit

Permalink
Shell implementation for RandomUniform. (openvinotoolkit#6782)
Browse files Browse the repository at this point in the history
* Added shell implementation for RandomUniform.

* Small correction.

* Small correction.

* Corrected wrong type.

* Corrected error message, corrected setters.
  • Loading branch information
popovaan authored Jul 29, 2021
1 parent a3d9f00 commit e70e7e1
Show file tree
Hide file tree
Showing 7 changed files with 508 additions and 1 deletion.
69 changes: 69 additions & 0 deletions ngraph/core/include/ngraph/op/random_uniform.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"

namespace ngraph
{
namespace op
{
namespace v8
{
/// \brief Tensor RandomUniform operation.
///
/// Produces a tensor of shape `out_shape` whose values are drawn uniformly
/// from [min_val, max_val); only type/shape inference is declared here.
class NGRAPH_API RandomUniform : public Op
{
public:
    NGRAPH_RTTI_DECLARATION;

    RandomUniform() = default;

    ///
    /// \brief Constructs a RandomUniform operation.
    ///
    /// \param out_shape   Node producing the tensor with output shape.
    /// \param min_val     Node producing the tensor with minimum value.
    /// \param max_val     Node producing the tensor with maximum value.
    /// \param out_type    Output type of the tensor.
    /// \param global_seed Global seed value.
    /// \param op_seed     Operational seed value.
    RandomUniform(const Output<Node>& out_shape,
                  const Output<Node>& min_val,
                  const Output<Node>& max_val,
                  const ngraph::element::Type& out_type,
                  uint64_t global_seed,
                  uint64_t op_seed);

    void validate_and_infer_types() override;

    bool visit_attributes(AttributeVisitor& visitor) override;

    std::shared_ptr<Node>
        clone_with_new_inputs(const OutputVector& new_args) const override;

    /// \return The output tensor type.
    const ngraph::element::Type& get_out_type() const { return m_output_type; }
    void set_out_type(const ngraph::element::Type& output_type)
    {
        m_output_type = output_type;
    }

    /// \return The global seed value.
    uint64_t get_global_seed() const { return m_global_seed; }
    void set_global_seed(uint64_t seed) { m_global_seed = seed; }

    /// \return The operational seed value.
    uint64_t get_op_seed() const { return m_op_seed; }
    void set_op_seed(uint64_t seed) { m_op_seed = seed; }

protected:
    ngraph::element::Type m_output_type;
    // Default-initialize the seeds: with `RandomUniform() = default;` above,
    // a default-constructed op would otherwise carry indeterminate seed values
    // into visit_attributes()/serialization.
    uint64_t m_global_seed = 0;
    uint64_t m_op_seed = 0;
};
} // namespace v8
} // namespace op
} // namespace ngraph
1 change: 1 addition & 0 deletions ngraph/core/include/ngraph/ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@
#include "ngraph/op/prior_box_clustered.hpp"
#include "ngraph/op/proposal.hpp"
#include "ngraph/op/psroi_pooling.hpp"
#include "ngraph/op/random_uniform.hpp"
#include "ngraph/op/range.hpp"
#include "ngraph/op/read_value.hpp"
#include "ngraph/op/reduce_l1.hpp"
Expand Down
3 changes: 2 additions & 1 deletion ngraph/core/include/ngraph/opsets/opset8_tbl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,4 +180,5 @@ NGRAPH_OP(AdaptiveAvgPool, ngraph::op::v8)
NGRAPH_OP(AdaptiveMaxPool, ngraph::op::v8)
NGRAPH_OP(DeformableConvolution, ngraph::op::v8)
NGRAPH_OP(MatrixNms, ngraph::op::v8)
NGRAPH_OP(MulticlassNms, ngraph::op::v8)
NGRAPH_OP(MulticlassNms, ngraph::op::v8)
NGRAPH_OP(RandomUniform, ngraph::op::v8)
144 changes: 144 additions & 0 deletions ngraph/core/src/op/random_uniform.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "ngraph/op/random_uniform.hpp"
#include <ngraph/validation_util.hpp>
#include "itt.hpp"

using namespace std;
using namespace ngraph;

NGRAPH_RTTI_DEFINITION(op::v8::RandomUniform, "RandomUniform", 8);

// RandomUniform constructor: the three nodes (output shape, min, max) become
// the op's inputs; out_type and the two seeds are stored as attributes.
// All validation is delegated to constructor_validate_and_infer_types(),
// which ends up in validate_and_infer_types() below.
op::v8::RandomUniform::RandomUniform(const Output<Node>& out_shape,
                                     const Output<Node>& min_val,
                                     const Output<Node>& max_val,
                                     const ngraph::element::Type& out_type,
                                     uint64_t global_seed,
                                     uint64_t op_seed)
    : Op({out_shape, min_val, max_val})
    , m_output_type(out_type)
    , m_global_seed(global_seed)
    , m_op_seed(op_seed)
{
    constructor_validate_and_infer_types();
}

// Validates the op's three inputs and infers the output type/shape:
//   input 0 — output shape (1-D i32/i64 tensor),
//   input 1 — min value (scalar or one-element 1-D tensor),
//   input 2 — max value (same constraints and element type as min).
// The output shape is resolved to a static shape only when input 0 can be
// constant-folded; otherwise it stays fully dynamic.
void op::v8::RandomUniform::validate_and_infer_types()
{
    NGRAPH_OP_SCOPE(v8_RandomUniform_validate_and_infer_types);

    // The shape-defining input must be integral; a dynamic type is accepted
    // and re-checked once it becomes known.
    const auto& shape_et = get_input_element_type(0);
    NODE_VALIDATION_CHECK(this,
                          shape_et.is_dynamic() || shape_et == element::i32 ||
                              shape_et == element::i64,
                          "Type of the input should be int32 or int64.");

    // Default to a fully dynamic output shape; refine it below if possible.
    PartialShape output_shape = PartialShape::dynamic();
    const auto& input_shape = get_input_partial_shape(0);
    if (input_shape.rank().is_static())
    {
        NODE_VALIDATION_CHECK(this,
                              input_shape.rank() == 1,
                              "The rank of the tensor defining output shape must be equal to 1.");
        // If the shape input folds to a constant, the output shape is static.
        if (const auto& const_shape = get_constant_from_source(input_value(0)))
        {
            output_shape = PartialShape(const_shape->cast_vector<int64_t>());
        }
    }

    // min/max must each be a scalar or a one-element 1-D tensor.
    // NOTE(review): only the rank is inspected below, so `rank().is_static()`
    // would suffice; `is_static()` additionally requires all dimensions to be
    // static, which skips this check for e.g. shape {?} — confirm intent.
    const auto& min_pshape = get_input_partial_shape(1);
    const auto& max_pshape = get_input_partial_shape(2);
    if (min_pshape.is_static())
    {
        const auto& min_rank = min_pshape.rank().get_length();
        NODE_VALIDATION_CHECK(this, min_rank <= 1, "Min value must be a scalar or 1D tensor.");

        if (min_rank == 1)
        {
            NODE_VALIDATION_CHECK(
                this, min_pshape.compatible(Shape{1}), "'min_val' should have 1 element.");
        }
    }

    if (max_pshape.is_static())
    {
        const auto& max_rank = max_pshape.rank().get_length();
        NODE_VALIDATION_CHECK(this, max_rank <= 1, "Max value must be a scalar or 1D tensor.");

        if (max_rank == 1)
        {
            NODE_VALIDATION_CHECK(
                this, max_pshape.compatible(Shape{1}), "'max_val' should have 1 element.");
        }
    }

    // min, max and the 'out_type' attribute must all agree on element type.
    const element::Type& min_element_type = get_input_element_type(1);
    element::Type max_element_type = get_input_element_type(2);
    NODE_VALIDATION_CHECK(this,
                          min_element_type == max_element_type,
                          "'min_val' should have the same type as 'max_val'.");
    NODE_VALIDATION_CHECK(
        this,
        min_element_type == get_out_type(),
        "'min_val' and 'max_val' should have the same type as 'out_type' attribute.");

    // When both bounds fold to constants, enforce min < max. The comparison is
    // done in int64 for integral outputs and in double for floating outputs;
    // any other out_type is rejected outright.
    if (const auto& const_min = get_constant_from_source(input_value(1)))
    {
        if (const auto& const_max = get_constant_from_source(input_value(2)))
        {
            if (get_out_type() == ngraph::element::Type_t::i64 ||
                get_out_type() == ngraph::element::Type_t::i32)
            {
                int64_t min_val = const_min->cast_vector<int64_t>()[0];
                int64_t max_val = const_max->cast_vector<int64_t>()[0];

                NODE_VALIDATION_CHECK(this,
                                      min_val < max_val,
                                      "Min value must be less than max value. Got "
                                      "min value: ",
                                      min_val,
                                      ", max value: ",
                                      max_val);
            }
            else if (get_out_type().is_real())
            {
                double min_val = const_min->cast_vector<double>()[0];
                double max_val = const_max->cast_vector<double>()[0];

                NODE_VALIDATION_CHECK(this,
                                      min_val < max_val,
                                      "Min value must be less than max value. Got "
                                      "min value: ",
                                      min_val,
                                      ", max value: ",
                                      max_val);
            }
            else
            {
                throw ngraph_error("Unsupported output type of RandomUniform: " +
                                   get_out_type().get_type_name());
            }
        }
    }

    set_output_type(0, get_out_type(), output_shape);
}

// Serializes/deserializes the op's attributes (output type and both seeds).
// NOTE(review): attributes are visited in (output_type, op_seed, global_seed)
// order — keep this order stable in case a positional serializer depends on it.
bool op::v8::RandomUniform::visit_attributes(AttributeVisitor& visitor)
{
    NGRAPH_OP_SCOPE(v8_RandomUniform_visit_attributes);
    visitor.on_attribute("output_type", m_output_type);
    visitor.on_attribute("op_seed", m_op_seed);
    visitor.on_attribute("global_seed", m_global_seed);
    return true;
}

// Creates a copy of this op wired to new input nodes, preserving the
// out_type/seed attributes. Required by the ngraph graph-cloning machinery.
shared_ptr<Node> op::v8::RandomUniform::clone_with_new_inputs(const OutputVector& new_args) const
{
    // Fixed copy-paste bug: the ITT profiling scope tag previously read
    // "v8_Roll_clone_with_new_inputs" (copied from the Roll op).
    NGRAPH_OP_SCOPE(v8_RandomUniform_clone_with_new_inputs);
    check_new_args_count(this, new_args);
    return make_shared<v8::RandomUniform>(
        new_args[0], new_args[1], new_args[2], m_output_type, m_global_seed, m_op_seed);
}
2 changes: 2 additions & 0 deletions ngraph/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ set(SRC
type_prop/proposal.cpp
type_prop/psroi_pooling.cpp
type_prop/prior_box_clustered.cpp
type_prop/random_uniform.cpp
type_prop/range.cpp
type_prop/read_value.cpp
type_prop/reduce_l1.cpp
Expand Down Expand Up @@ -298,6 +299,7 @@ set(SRC
visitors/op/prior_box_clustered.cpp
visitors/op/proposal.cpp
visitors/op/psroi_pooling.cpp
visitors/op/random_uniform.cpp
visitors/op/reduce_l1.cpp
visitors/op/reduce_l2.cpp
visitors/op/reduce_logical_and.cpp
Expand Down
Loading

0 comments on commit e70e7e1

Please sign in to comment.