Skip to content

Commit

Permalink
Conversion via CNNNetworkImpl ctor (openvinotoolkit#1222)
Browse files Browse the repository at this point in the history
* Added ctor for CNNNetworkImpl to convert from ngraphImpl

* Re-use in all places instead of manual conversion

* Hide convertToCNNNetworkImpl usage

* Remove useless test

* Fixed Gleb's comments
  • Loading branch information
ilya-lavrenov authored Jul 8, 2020
1 parent c39e32a commit 884389d
Show file tree
Hide file tree
Showing 15 changed files with 91 additions and 178 deletions.
19 changes: 3 additions & 16 deletions inference-engine/src/hetero_plugin/hetero_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,6 @@
#include "hetero/hetero_plugin_config.hpp"
#include <cpp_interfaces/base/ie_plugin_base.hpp>
#include "hetero_executable_network.hpp"
#include "convert_function_to_cnn_network.hpp"
#include <generic_ie.hpp>
#include <transformations/common_optimizations/common_optimizations.hpp>
#include <transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.hpp>
#include <transformations/convert_opset2_to_opset1/convert_opset2_to_opset1.hpp>
#include <transformations/convert_opset3_to_opset2/convert_opset3_to_opset2.hpp>

using namespace InferenceEngine;
using namespace InferenceEngine::PluginConfigParams;
Expand Down Expand Up @@ -63,8 +57,7 @@ InferenceEngine::ExecutableNetworkInternal::Ptr Engine::LoadExeNetworkImpl(const
}
DeviceMetaInformationMap metaDevices = GetDevicePlugins(it->second, tconfig);

auto function = network.getFunction();
if (function != nullptr) {
if (auto function = network.getFunction()) {
auto anyDeviceDoNotSupportNgraph =
std::any_of(std::begin(metaDevices), std::end(metaDevices),
[&] (const DeviceMetaInformationMap::value_type& metaDevice) {
Expand All @@ -74,15 +67,9 @@ InferenceEngine::ExecutableNetworkInternal::Ptr Engine::LoadExeNetworkImpl(const
return (clonedNetwork->getFunction() == nullptr);
});
if (anyDeviceDoNotSupportNgraph) {
auto clonedNetwork = cloneNetwork(network);
auto function = clonedNetwork->getFunction();
::ngraph::op::GenericIE::DisableReshape noReshape(function);
::ngraph::pass::CommonOptimizations().run_on_function(function);
::ngraph::pass::ConvertOpSet3ToOpSet2().run_on_function(function);
::ngraph::pass::ConvertOpSet2ToOpSet1().run_on_function(function);
::ngraph::pass::ConvertOpSet1ToLegacy().run_on_function(function);
auto cnnNetworkImpl = std::make_shared<details::CNNNetworkImpl>(network);
return std::make_shared<HeteroExecutableNetwork>(
*InferenceEngine::details::convertFunctionToICNNNetwork(function, *clonedNetwork),
*cnnNetworkImpl,
mergeConfigs(_config, config), this);
} else {
return std::make_shared<HeteroExecutableNetwork>(*cloneNetwork(network), mergeConfigs(_config, config), this);
Expand Down
39 changes: 4 additions & 35 deletions inference-engine/src/inference_engine/cnn_network_ngraph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,8 @@
#include <ngraph/ngraph.hpp>
#include <ngraph/pass/get_output_element_elimination.hpp>
#include <set>
// #include <shape_infer/ie_reshaper.hpp>
#include <string>

#include <transformations/common_optimizations/common_optimizations.hpp>
#include <transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.hpp>
#include <transformations/convert_opset2_to_opset1/convert_opset2_to_opset1.hpp>
#include <transformations/convert_opset3_to_opset2/convert_opset3_to_opset2.hpp>
#include <transformations/convert_opset1_to_legacy/convert_one_hot_to_one_hot_ie.hpp>

#include "ngraph_ops/eltwise.hpp"
Expand All @@ -35,7 +30,6 @@
#include "ie_profiling.hpp"
#include "network_serializer.h"
#include "generic_ie.hpp"
#include "convert_function_to_cnn_network.hpp"
#include <shape_infer/built-in/ie_built_in_holder.hpp>

using namespace std;
Expand Down Expand Up @@ -110,12 +104,6 @@ void CNNNetworkNGraphImpl::createDataForResult(const ::ngraph::Output<::ngraph::
}
}

// Returns the legacy CNNNetwork view of this nGraph-based network.
// Lazily converts on first use: if no cached conversion exists yet,
// convertToCNNNetworkImpl() populates `cnnNetwork`, and the cached
// pointer is returned on every subsequent call.
std::shared_ptr<ICNNNetwork> CNNNetworkNGraphImpl::getCNNNetwork() {
    if (!cnnNetwork)
        convertToCNNNetworkImpl();
    return cnnNetwork;
}

CNNNetworkNGraphImpl::CNNNetworkNGraphImpl(const std::shared_ptr<Function>& nGraph)
: _ngraph_function(nGraph) {
// Restore usual attributes for ICNNNetwork
Expand Down Expand Up @@ -325,9 +313,7 @@ CNNNetworkNGraphImpl::reshape(const std::map<std::string, std::vector<size_t>>&
}
_ngraph_function->validate_nodes_and_infer_types();

if (cnnNetwork) {
convertToCNNNetworkImpl();
} else {
{
auto specialized_ngraph_function = cloneFunction(true, inputShapes);
// Call this transformation because OneHot IE and nGraph have different output precisions
{
Expand Down Expand Up @@ -430,15 +416,7 @@ StatusCode CNNNetworkNGraphImpl::serialize(const std::string& xmlPath, const std
return DescriptionBuffer(UNEXPECTED, resp);
}

auto graph = cloneFunction();
// Disable shape inference (WA for generic operations)
::ngraph::op::GenericIE::DisableReshape noReshape(graph);

::ngraph::pass::CommonOptimizations().run_on_function(graph);
::ngraph::pass::ConvertOpSet3ToOpSet2().run_on_function(graph);
::ngraph::pass::ConvertOpSet2ToOpSet1().run_on_function(graph);
::ngraph::pass::ConvertOpSet1ToLegacy().run_on_function(graph);
network = InferenceEngine::details::convertFunctionToICNNNetwork(graph, *this);
network = std::make_shared<details::CNNNetworkImpl>(*this);
}
if (!network) return GENERAL_ERROR;
return network->serialize(xmlPath, binPath, resp);
Expand Down Expand Up @@ -492,15 +470,6 @@ StatusCode CNNNetworkNGraphImpl::setBatchSizeReshape(size_t size, ResponseDesc*

void CNNNetworkNGraphImpl::convertToCNNNetworkImpl() {
IE_PROFILING_AUTO_SCOPE(convertToCNNNetworkImpl)
if (cnnNetwork)
return;
auto graph = cloneFunction();
// Disable shape inference (WA for generic operations)
::ngraph::op::GenericIE::DisableReshape noReshape(graph);

::ngraph::pass::CommonOptimizations().run_on_function(graph);
::ngraph::pass::ConvertOpSet3ToOpSet2().run_on_function(graph);
::ngraph::pass::ConvertOpSet2ToOpSet1().run_on_function(graph);
::ngraph::pass::ConvertOpSet1ToLegacy().run_on_function(graph);
cnnNetwork = InferenceEngine::details::convertFunctionToICNNNetwork(graph, *this);
if (!cnnNetwork)
cnnNetwork = std::make_shared<details::CNNNetworkImpl>(*this);
}
21 changes: 13 additions & 8 deletions inference-engine/src/inference_engine/cnn_network_ngraph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,6 @@ class INFERENCE_ENGINE_API_CLASS(CNNNetworkNGraphImpl): public ICNNNetwork {

void setInputInfo(InputInfo::Ptr data);

std::shared_ptr<ICNNNetwork> getCNNNetwork();

void addLayer(const CNNLayerPtr& layer) noexcept;

// public version
Expand Down Expand Up @@ -91,11 +89,11 @@ class INFERENCE_ENGINE_API_CLASS(CNNNetworkNGraphImpl): public ICNNNetwork {
StatusCode serialize(const std::string& xmlPath, const std::string& binPath, ResponseDesc* resp) const
noexcept override;

void convertToCNNNetworkImpl();
protected:
std::shared_ptr<::ngraph::Function> _ngraph_function;
virtual std::shared_ptr<::ngraph::Function> cloneFunction(bool constFolding = false, const std::map<std::string,
std::vector<size_t>>& inputShapes = {}) const;
protected:
std::shared_ptr<::ngraph::Function> _ngraph_function;

private:
std::map<std::string, DataPtr> _data;
InferenceEngine::InputsDataMap _inputData;
Expand All @@ -111,10 +109,18 @@ class INFERENCE_ENGINE_API_CLASS(CNNNetworkNGraphImpl): public ICNNNetwork {
*/
void createDataForResult(const ::ngraph::Output<::ngraph::Node>& output, const std::string& outName, DataPtr& ptr);

friend INFERENCE_ENGINE_API_CPP(std::shared_ptr<CNNNetworkImpl>)
/**
* @brief Converts ngraph::Function to old CNNNetworkImpl representation
*/
void convertToCNNNetworkImpl();

friend INFERENCE_ENGINE_API_CPP(void)
convertFunctionToICNNNetwork(const std::shared_ptr<const ::ngraph::Function>& graph,
const ICNNNetwork& nGraphImpl, bool keep_constant_inputs);
const ICNNNetwork& nGraphImpl,
CNNNetworkImpl* cnnNetworkImpl,
bool keep_constant_inputs);

friend class NGraphData;

/**
* @brief Reshape on the same shape
Expand All @@ -126,7 +132,6 @@ class TINGraphBody : public CNNNetworkNGraphImpl {
public:
explicit TINGraphBody(const std::shared_ptr<::ngraph::Function>& func): CNNNetworkNGraphImpl(func) {}

protected:
std::shared_ptr<::ngraph::Function> cloneFunction(bool constFolding, const std::map<std::string, std::vector<size_t>>& inputShapes) const override {
return _ngraph_function;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ namespace details {
class INFERENCE_ENGINE_API_CLASS(CNNNetworkImpl): public ICNNNetwork {
public:
CNNNetworkImpl();
explicit CNNNetworkImpl(const ICNNNetwork & ngraphImpl);
~CNNNetworkImpl() override;

std::shared_ptr<::ngraph::Function> getFunction() noexcept override {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ INFERENCE_ENGINE_API_CPP(std::shared_ptr<CNNNetworkImpl>)
convertFunctionToICNNNetwork(const std::shared_ptr<const ::ngraph::Function>& graph,
const ICNNNetwork &network, bool keep_constant_inputs = false);

INFERENCE_ENGINE_API_CPP(void)
convertFunctionToICNNNetwork(const std::shared_ptr<const ::ngraph::Function>& graph,
const ICNNNetwork &ngraphNetwork,
CNNNetworkImpl* cnnNetworkImpl,
bool keep_constant_inputs = false);


} // namespace details
} // namespace InferenceEngine
6 changes: 2 additions & 4 deletions inference-engine/src/legacy_api/include/graph_transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,7 @@ namespace InferenceEngine {
*/
class INFERENCE_ENGINE_API_CLASS(ConstTransformer) {
public:
explicit ConstTransformer(ICNNNetwork* _network);
explicit ConstTransformer(details::CNNNetworkImpl* _network);
explicit ConstTransformer(std::vector<DataPtr> &_inputs, std::vector<DataPtr> &_outputs);

virtual ~ConstTransformer() = default;

/**
* @brief calculates const layers, combines const subgraph into a single const layers
Expand All @@ -41,6 +37,8 @@ class INFERENCE_ENGINE_API_CLASS(ConstTransformer) {
void fullTrim();

protected:
ConstTransformer(std::vector<DataPtr> &_inputs, std::vector<DataPtr> &_outputs);

/**
* @brief collect all const layers with marking if it defines shape (1 - for shape, 0 - otherwise)
*/
Expand Down
23 changes: 23 additions & 0 deletions inference-engine/src/legacy_api/src/cnn_network_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@
#include "network_serializer.h"
#include "details/ie_cnn_network_tools.h"

#include "generic_ie.hpp"
#include "cnn_network_ngraph_impl.hpp"
#include <transformations/common_optimizations/common_optimizations.hpp>
#include <transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.hpp>
#include <transformations/convert_opset2_to_opset1/convert_opset2_to_opset1.hpp>
#include <transformations/convert_opset3_to_opset2/convert_opset3_to_opset2.hpp>
#include "convert_function_to_cnn_network.hpp"

using namespace std;
using namespace InferenceEngine;
using namespace InferenceEngine::details;
Expand Down Expand Up @@ -78,6 +86,21 @@ ICNNNetwork::~ICNNNetwork() {}

CNNNetworkImpl::CNNNetworkImpl() {}

// Converting constructor: builds the legacy layer-based representation
// from an nGraph-based network (must be a CNNNetworkNGraphImpl holding
// a non-null ngraph::Function).
CNNNetworkImpl::CNNNetworkImpl(const ICNNNetwork & ngraphImpl) {
    const auto* ngraphNetwork = dynamic_cast<const details::CNNNetworkNGraphImpl*>(&ngraphImpl);
    IE_ASSERT(ngraphNetwork != nullptr);
    IE_ASSERT(ngraphNetwork->getFunction() != nullptr);

    // Transform a clone so the source network is left untouched.
    auto clonedFunction = ngraphNetwork->cloneFunction();
    // Disable shape inference (WA for generic operations)
    ::ngraph::op::GenericIE::DisableReshape noReshape(clonedFunction);

    // Lower the function step by step down to the legacy opset
    // (order matters: opset3 -> opset2 -> opset1 -> legacy).
    ::ngraph::pass::CommonOptimizations().run_on_function(clonedFunction);
    ::ngraph::pass::ConvertOpSet3ToOpSet2().run_on_function(clonedFunction);
    ::ngraph::pass::ConvertOpSet2ToOpSet1().run_on_function(clonedFunction);
    ::ngraph::pass::ConvertOpSet1ToLegacy().run_on_function(clonedFunction);

    // Fill *this* in place with the converted layers and data.
    InferenceEngine::details::convertFunctionToICNNNetwork(clonedFunction, ngraphImpl, this, false);
}

CNNNetworkImpl::~CNNNetworkImpl() {
// In case of cycles, memory leaks occur: Layer holds shared_ptr<Data>, and vice versa.
// Added additional check on cycles.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@

#include <debug.h>
#include <ngraph/opsets/opset1.hpp>
#include "transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.hpp"
#include "transformations/utils/utils.hpp"
#include "transformations/rt_info/fused_names_attribute.hpp"
#include "transformations/rt_info/primitives_priority_attribute.hpp"
Expand Down Expand Up @@ -508,9 +507,10 @@ CNNLayerPtr InferenceEngine::details::CNNLayerCreator::create() {
return res;
}

std::shared_ptr<CNNNetworkImpl> convertFunctionToICNNNetwork(const std::shared_ptr<const ::ngraph::Function> &graph,
const ICNNNetwork &network,
bool keep_constant_inputs) {
void convertFunctionToICNNNetwork(const std::shared_ptr<const ::ngraph::Function> &graph,
const ICNNNetwork &network,
CNNNetworkImpl* cnnNetworkImpl,
bool keep_constant_inputs) {
IE_PROFILING_AUTO_SCOPE(convertFunctionToICNNNetwork)
const auto createCNNLayer = [](const std::shared_ptr<::ngraph::Node> &node) -> CNNLayerPtr {
class NGraphCNNLayer: public CNNLayer {
Expand Down Expand Up @@ -698,7 +698,7 @@ std::shared_ptr<CNNNetworkImpl> convertFunctionToICNNNetwork(const std::shared_p
return ::ngraph::as_type_ptr<::ngraph::op::Result>(node) != nullptr;
};

const auto keep_input_info = [](std::shared_ptr<details::CNNNetworkImpl> &network, const DataPtr &inData) {
const auto keep_input_info = [](CNNNetworkImpl *network, const DataPtr &inData) {
InputInfo::Ptr info(new InputInfo());
info->setInputData(inData);
network->setInputInfo(info);
Expand All @@ -709,8 +709,7 @@ std::shared_ptr<CNNNetworkImpl> convertFunctionToICNNNetwork(const std::shared_p
InputsDataMap thisInputDataMap;
network.getInputsInfo(thisInputDataMap);

// Create network
auto cnnNetworkImpl = std::make_shared<details::CNNNetworkImpl>();
// Construct network
cnnNetworkImpl->setName(graph->get_friendly_name());

// Collect all names from current graph
Expand Down Expand Up @@ -913,7 +912,15 @@ std::shared_ptr<CNNNetworkImpl> convertFunctionToICNNNetwork(const std::shared_p
for (const auto &ext : ::ngraph::op::GenericIE::getExtensions(graph)) {
cnnNetworkImpl->AddExtension(ext, nullptr);
}
}

// Convenience overload: allocates a fresh CNNNetworkImpl, delegates to the
// in-place conversion routine, and hands ownership back to the caller.
std::shared_ptr<CNNNetworkImpl> convertFunctionToICNNNetwork(const std::shared_ptr<const ::ngraph::Function> &graph,
                                                             const ICNNNetwork &network,
                                                             bool keep_constant_inputs) {
    auto result = std::make_shared<details::CNNNetworkImpl>();
    convertFunctionToICNNNetwork(graph, network, result.get(), keep_constant_inputs);
    return result;
}

} // namespace details
} // namespace InferenceEngine
14 changes: 0 additions & 14 deletions inference-engine/src/legacy_api/src/graph_transformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#include <mutex>
#include <algorithm>

#include <cnn_network_ngraph_impl.hpp>
#include "blob_factory.hpp"
#include "cnn_network_impl.hpp"
#include "graph_tools.hpp"
Expand Down Expand Up @@ -71,19 +70,6 @@ ConstTransformer::ConstTransformer(details::CNNNetworkImpl* _network)
THROW_IE_EXCEPTION << "[ERROR]: Failed to init ConstTransformer with null pointer of network";
}

// Initializes the transformer from a generic ICNNNetwork pointer.
// Accepts either a legacy CNNNetworkImpl directly, or an nGraph-based
// CNNNetworkNGraphImpl (in which case its converted legacy network is
// used — NOTE(review): getCNNNetwork() presumably performs the
// nGraph -> legacy conversion on demand; confirm against its definition).
// Throws when the network is of neither supported type.
ConstTransformer::ConstTransformer(ICNNNetwork* _network) {
    if (auto cnnNet = dynamic_cast<InferenceEngine::details::CNNNetworkImpl *>(_network)) {
        network = cnnNet;
    } else if (auto nGraphNet = dynamic_cast<InferenceEngine::details::CNNNetworkNGraphImpl *>(_network)) {
        if (auto cnnNet = dynamic_cast<InferenceEngine::details::CNNNetworkImpl *>(nGraphNet->getCNNNetwork().get()))
            network = cnnNet;
    }
    if (!network)
        THROW_IE_EXCEPTION << "[ERROR]: Failed to init ConstTransformer with unsupported network type";
    // Cache the network's input/output Data objects for the trim passes.
    inputs = get_inputs(network);
    outputs = get_outputs(network);
}

ConstTransformer::ConstTransformer(std::vector<DataPtr> &_inputs, std::vector<DataPtr> &_outputs)
: network(nullptr), inputs(_inputs), outputs(_outputs) {
if (inputs.empty() || outputs.empty())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
#include <utility>
#include <vector>

#include "cnn_network_ngraph_impl.hpp"
#include "details/os/os_filesystem.hpp"
#include "ie_format_parser.h"
#include "ie_profiling.hpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ void FrontEnd::removeConstLayers(ie::ICNNNetwork& network) {
env.log->trace("Remove const layers");
VPU_LOGGER_SECTION(env.log);

ie::ConstTransformer(&network).fullTrim();
auto implNetwork = dynamic_cast<ie::details::CNNNetworkImpl *>(&network);
VPU_THROW_UNLESS(implNetwork != nullptr, "FrontEnd::removeConstLayers expects CNNNetworkImpl");

ie::ConstTransformer(implNetwork).fullTrim();
}

} // namespace vpu
Loading

0 comments on commit 884389d

Please sign in to comment.