Skip to content

Commit

Permalink
Use OV RTTI for ConversionExtensions
Browse files Browse the repository at this point in the history
  • Loading branch information
ilya-lavrenov committed Feb 5, 2025
1 parent c27f796 commit 15fc254
Show file tree
Hide file tree
Showing 37 changed files with 123 additions and 97 deletions.
1 change: 0 additions & 1 deletion cmake/developer_package/compile_flags/os_flags.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,6 @@ endfunction()
# ov_target_link_libraries_as_system(<TARGET NAME> <PUBLIC | PRIVATE | INTERFACE> <target1 target2 ...>)
#
function(ov_target_link_libraries_as_system TARGET_NAME LINK_TYPE)
message("Link to ${TARGET_NAME} using ${LINK_TYPE} the following ${ARGN}")
target_link_libraries(${TARGET_NAME} ${LINK_TYPE} ${ARGN})

# include directories as SYSTEM
Expand Down
12 changes: 11 additions & 1 deletion src/cmake/openvino.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,18 @@ ov_add_vs_version_file(NAME ${TARGET_NAME} FILEDESCRIPTION "OpenVINO runtime lib

target_include_directories(${TARGET_NAME} PUBLIC
$<BUILD_INTERFACE:${OpenVINO_SOURCE_DIR}/src/core/include>
$<BUILD_INTERFACE:${OpenVINO_SOURCE_DIR}/src/inference/include>
$<BUILD_INTERFACE:${OpenVINO_SOURCE_DIR}/src/frontends/common/include>)

# to be aligned with OpenVINO archive, where all headers are located in the same folder and
# exposed via openvino::runtime
target_include_directories(${TARGET_NAME} INTERFACE
$<BUILD_INTERFACE:${OpenVINO_SOURCE_DIR}/src/frontends/common/include>
$<BUILD_INTERFACE:${OpenVINO_SOURCE_DIR}/src/inference/include>)
$<BUILD_INTERFACE:${OpenVINO_SOURCE_DIR}/src/frontends/onnx/frontend/include>
$<BUILD_INTERFACE:${OpenVINO_SOURCE_DIR}/src/frontends/paddle/include>
$<BUILD_INTERFACE:${OpenVINO_SOURCE_DIR}/src/frontends/pytorch/include>
$<BUILD_INTERFACE:${OpenVINO_SOURCE_DIR}/src/frontends/tensorflow/include>
$<BUILD_INTERFACE:${OpenVINO_SOURCE_DIR}/src/frontends/tensorflow_lite/include>)

target_link_libraries(${TARGET_NAME} PRIVATE openvino::reference
openvino::shape_inference
Expand Down
15 changes: 12 additions & 3 deletions src/core/include/openvino/core/extension.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "openvino/core/core_visibility.hpp"
#include "openvino/core/type.hpp"
#include "openvino/core/rtti.hpp"

#define OPENVINO_EXTENSION_C_API OPENVINO_EXTERN_C OPENVINO_CORE_EXPORTS
#define OPENVINO_EXTENSION_API OPENVINO_CORE_EXPORTS
Expand All @@ -24,6 +25,14 @@ class Extension;
*/
class OPENVINO_API Extension {
public:
_OPENVINO_HIDDEN_METHOD static const DiscreteTypeInfo& get_type_info_static() {
static const ::ov::DiscreteTypeInfo type_info_static{"Extension"};
return type_info_static;
}
virtual const DiscreteTypeInfo& get_type_info() const {
return get_type_info_static();
}

using Ptr = std::shared_ptr<Extension>;

virtual ~Extension();
Expand All @@ -37,15 +46,15 @@ class OPENVINO_API Extension {
/**
* @brief The entry point for library with OpenVINO extensions
*
* @param vector of extensions
* @param ext of extensions
*/
OPENVINO_EXTENSION_C_API
void OV_CREATE_EXTENSION(std::vector<ov::Extension::Ptr>&);
void OV_CREATE_EXTENSION(std::vector<ov::Extension::Ptr>& ext);

/**
* @brief Macro generates the entry point for the library
*
* @param vector of extensions
* @param ext of extensions
*/
#define OPENVINO_CREATE_EXTENSIONS(extensions) \
OPENVINO_EXTENSION_C_API void OV_CREATE_EXTENSION(std::vector<ov::Extension::Ptr>& ext); \
Expand Down
7 changes: 1 addition & 6 deletions src/core/include/openvino/core/op_extension.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,7 @@ namespace ov {
class OPENVINO_API BaseOpExtension : public Extension {
public:
using Ptr = std::shared_ptr<BaseOpExtension>;
/**
* @brief Returns the type info of operation
*
* @return ov::DiscreteTypeInfo
*/
virtual const ov::DiscreteTypeInfo& get_type_info() const = 0;

/**
* @brief Method creates an OpenVINO operation
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ class OPENVINO_API TensorInfoMemoryType : public RuntimeAttribute {

TensorInfoMemoryType() = default;

~TensorInfoMemoryType() override;

explicit TensorInfoMemoryType(const std::string& value) : value(value) {}

bool visit_attributes(AttributeVisitor& visitor) override {
Expand Down
4 changes: 0 additions & 4 deletions src/core/include/openvino/core/rtti.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,3 @@
_OPENVINO_RTTI_WITH_TYPE_VERSION_PARENT, \
_OPENVINO_RTTI_WITH_TYPE_VERSION, \
_OPENVINO_RTTI_WITH_TYPE)(__VA_ARGS__))

/// Note: Please don't use this macros for new operations
#define BWDCMP_RTTI_DECLARATION
#define BWDCMP_RTTI_DEFINITION(CLASS)
4 changes: 1 addition & 3 deletions src/core/include/openvino/core/shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,7 @@ OPENVINO_API
std::ostream& operator<<(std::ostream& s, const Shape& shape);

template <>
class OPENVINO_API AttributeAdapter<ov::Shape> : public IndirectVectorValueAccessor<ov::Shape, std::vector<int64_t>>

{
class OPENVINO_API AttributeAdapter<ov::Shape> : public IndirectVectorValueAccessor<ov::Shape, std::vector<int64_t>> {
public:
OPENVINO_RTTI("AttributeAdapter<Shape>");

Expand Down
8 changes: 7 additions & 1 deletion src/core/include/openvino/core/type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,15 @@ struct AsTypePtr<std::shared_ptr<In>> {
};
} // namespace util


namespace frontend {
class ConversionExtensionBase;
} // frontend

/// Casts a std::shared_ptr<Value> to a std::shared_ptr<Type> if it is of type
/// Type, nullptr otherwise
template <typename T, typename U>
template <typename T, typename U,
typename = std::enable_if_t<!std::is_base_of_v<ov::frontend::ConversionExtensionBase, T>>>
auto as_type_ptr(const U& value) -> decltype(::ov::util::AsTypePtr<U>::template call<T>(value)) {
#ifdef OPENVINO_DYNAMIC_CAST
return ::ov::util::AsTypePtr<U>::template call<T>(value);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class OPENVINO_API PrecisionSensitive : public RuntimeAttribute {

PrecisionSensitive() = default;

~PrecisionSensitive() override;

bool is_copyable() const override {
return false;
}
Expand Down
1 change: 1 addition & 0 deletions src/core/include/openvino/op/util/symbolic_info.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class OPENVINO_API SkipInvalidation : public RuntimeAttribute {
public:
OPENVINO_RTTI("SkipInvalidation", "0", RuntimeAttribute);
SkipInvalidation() = default;
~SkipInvalidation() override;
bool is_copyable() const override {
return false;
}
Expand Down
2 changes: 2 additions & 0 deletions src/core/src/op/util/precision_sensitive_attribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

#include "openvino/op/util/precision_sensitive_attribute.hpp"

ov::PrecisionSensitive::~PrecisionSensitive() = default;

void ov::mark_as_precision_sensitive(ov::Input<ov::Node> node_input) {
auto& rt_info = node_input.get_rt_info();
rt_info[PrecisionSensitive::get_type_info_static()] = PrecisionSensitive{};
Expand Down
2 changes: 2 additions & 0 deletions src/core/src/op/util/symbolic_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

#include "openvino/op/util/multi_subgraph_base.hpp"

ov::SkipInvalidation::~SkipInvalidation() = default;

void ov::skip_invalidation(const ov::Output<ov::Node>& output) {
output.get_tensor().get_rt_info()[ov::SkipInvalidation::get_type_info_static()] = nullptr;
}
Expand Down
3 changes: 3 additions & 0 deletions src/core/src/preprocess/pre_post_process.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,9 @@ std::shared_ptr<Model> PrePostProcessor::build() {
return function;
}

// ------------------ TensorInfoMemoryType ----------------
TensorInfoMemoryType::~TensorInfoMemoryType() = default;

// --------------------- InputTensorInfo ------------------
InputTensorInfo::InputTensorInfo() : m_impl(std::unique_ptr<InputTensorInfoImpl>(new InputTensorInfoImpl())) {}
InputTensorInfo::~InputTensorInfo() = default;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
namespace ov {
namespace frontend {

class FRONTEND_API ConversionExtensionBase : public ov::Extension {
class FRONTEND_API ConversionExtensionBase : public Extension {
public:
OPENVINO_RTTI("ConversionExtensionBase", "0", Extension);

using Ptr = std::shared_ptr<ConversionExtensionBase>;
explicit ConversionExtensionBase(const std::string& op_type) : m_op_type(op_type) {}

Expand All @@ -28,6 +30,8 @@ class FRONTEND_API ConversionExtensionBase : public ov::Extension {

class FRONTEND_API ConversionExtension : public ConversionExtensionBase {
public:
OPENVINO_RTTI("ConversionExtension", "0", ConversionExtensionBase);

using Ptr = std::shared_ptr<ConversionExtension>;
ConversionExtension(const std::string& op_type, const CreatorFunction& converter)
: ConversionExtensionBase(op_type),
Expand Down Expand Up @@ -61,5 +65,11 @@ class FRONTEND_API ConversionExtension : public ConversionExtensionBase {
CreatorFunctionNamedAndIndexed m_converter_named_and_indexed;
};

template <typename Type, typename Value,
typename std::enable_if<std::is_same<Type, ConversionExtension>::value, bool>::type = true>
auto as_type_ptr(const Value& value) -> decltype(::ov::util::AsTypePtr<Value>::template call<Type>(value)) {
return std::dynamic_pointer_cast<Type>(value);
}

} // namespace frontend
} // namespace ov
3 changes: 2 additions & 1 deletion src/frontends/ir/src/frontend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,9 @@ void FrontEnd::add_extension(const ov::Extension::Ptr& ext) {
if (std::dynamic_pointer_cast<ov::BaseOpExtension>(so_ext->extension())) {
m_extensions.emplace_back(so_ext->extension());
}
} else if (std::dynamic_pointer_cast<ov::BaseOpExtension>(ext))
} else if (std::dynamic_pointer_cast<ov::BaseOpExtension>(ext)) {
m_extensions.emplace_back(ext);
}
}

InputModel::Ptr FrontEnd::load_impl(const std::vector<ov::Any>& variants) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ namespace ov {
namespace frontend {
namespace jax {

class JAX_FRONTEND_API ConversionExtension : public ConversionExtensionBase {
class ConversionExtension : public ConversionExtensionBase {
public:
OPENVINO_RTTI("JaxConversionExtension", "0", ConversionExtensionBase);

using Ptr = std::shared_ptr<ConversionExtension>;

ConversionExtension() = delete;
Expand All @@ -27,12 +29,16 @@ class JAX_FRONTEND_API ConversionExtension : public ConversionExtensionBase {
return m_converter;
}

~ConversionExtension() override;

private:
ov::frontend::CreatorFunction m_converter;
};

template <typename Type, typename Value,
typename std::enable_if<std::is_same<Type, ConversionExtension>::value, bool>::type = true>
auto as_type_ptr(const Value& value) -> decltype(::ov::util::AsTypePtr<Value>::template call<Type>(value)) {
return ::ov::util::AsTypePtr<Value>::template call<Type>(value);
}

} // namespace jax
} // namespace frontend
} // namespace ov
7 changes: 0 additions & 7 deletions src/frontends/jax/src/extensions.cpp

This file was deleted.

2 changes: 1 addition & 1 deletion src/frontends/jax/src/frontend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ void FrontEnd::add_extension(const std::shared_ptr<ov::Extension>& extension) {
m_op_extension_translators[conv_ext->get_op_type()] = [=](const NodeContext& context) {
return conv_ext->get_converter()(context);
};
} else if (auto conv_ext = std::dynamic_pointer_cast<ov::frontend::jax::ConversionExtension>(extension)) {
} else if (auto conv_ext = as_type_ptr<ov::frontend::jax::ConversionExtension>(extension)) {
m_conversion_extensions.push_back(conv_ext);
m_op_extension_translators[conv_ext->get_op_type()] = [=](const NodeContext& context) {
return conv_ext->get_converter()(context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
namespace ov {
namespace frontend {
namespace onnx {
class ONNX_FRONTEND_API ConversionExtension : public ConversionExtensionBase {
class ConversionExtension : public ConversionExtensionBase {
public:
OPENVINO_RTTI("OnnxConversionExtension", "0", ConversionExtensionBase);

using Ptr = std::shared_ptr<ConversionExtension>;

ConversionExtension(const std::string& op_type, const ov::frontend::CreatorFunction& converter)
Expand All @@ -26,8 +28,6 @@ class ONNX_FRONTEND_API ConversionExtension : public ConversionExtensionBase {
m_domain{domain},
m_converter(converter) {}

~ConversionExtension() override;

const std::string& get_domain() const {
return m_domain;
}
Expand All @@ -37,9 +37,16 @@ class ONNX_FRONTEND_API ConversionExtension : public ConversionExtensionBase {
}

private:
std::string m_domain = "";
std::string m_domain;
ov::frontend::CreatorFunction m_converter;
};

template <typename Type, typename Value,
typename std::enable_if<std::is_same<Type, ConversionExtension>::value, bool>::type = true>
auto as_type_ptr(const Value& value) -> decltype(::ov::util::AsTypePtr<Value>::template call<Type>(value)) {
return ::ov::util::AsTypePtr<Value>::template call<Type>(value);
}

} // namespace onnx
} // namespace frontend
} // namespace ov
2 changes: 1 addition & 1 deletion src/frontends/onnx/frontend/src/core/graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ OperatorsBridge register_extensions(OperatorsBridge& bridge,
return common_conv_ext->get_converter()(ov::frontend::onnx::NodeContext(node));
});
} else if (const auto onnx_conv_ext =
std::dynamic_pointer_cast<ov::frontend::onnx::ConversionExtension>(extension)) {
as_type_ptr<ov::frontend::onnx::ConversionExtension>(extension)) {
bridge.overwrite_operator(onnx_conv_ext->get_op_type(),
onnx_conv_ext->get_domain(),
[onnx_conv_ext](const ov::frontend::onnx::Node& node) -> ov::OutputVector {
Expand Down
7 changes: 0 additions & 7 deletions src/frontends/onnx/frontend/src/extensions.cpp

This file was deleted.

2 changes: 1 addition & 1 deletion src/frontends/onnx/frontend/src/frontend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ void FrontEnd::add_extension(const std::shared_ptr<ov::Extension>& extension) {
m_other_extensions.push_back(so_ext);
} else if (auto common_conv_ext = std::dynamic_pointer_cast<ov::frontend::ConversionExtension>(extension)) {
m_extensions.conversions.push_back(common_conv_ext);
} else if (const auto onnx_conv_ext = std::dynamic_pointer_cast<onnx::ConversionExtension>(extension)) {
} else if (const auto onnx_conv_ext = as_type_ptr<onnx::ConversionExtension>(extension)) {
m_extensions.conversions.push_back(onnx_conv_ext);
} else if (auto progress_reporter = std::dynamic_pointer_cast<ProgressReporterExtension>(extension)) {
m_extensions.progress_reporter = progress_reporter;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ namespace ov {
namespace frontend {
namespace paddle {

class PADDLE_FRONTEND_API ConversionExtension : public ConversionExtensionBase {
class ConversionExtension : public ConversionExtensionBase {
public:
OPENVINO_RTTI("PaddleConversionExtension", "0", ConversionExtensionBase);

using Ptr = std::shared_ptr<ConversionExtension>;

ConversionExtension() = delete;
Expand All @@ -23,8 +25,6 @@ class PADDLE_FRONTEND_API ConversionExtension : public ConversionExtensionBase {
: ConversionExtensionBase(op_type),
m_converter(converter) {}

~ConversionExtension() override;

const ov::frontend::CreatorFunctionNamed& get_converter() const {
return m_converter;
}
Expand All @@ -33,6 +33,12 @@ class PADDLE_FRONTEND_API ConversionExtension : public ConversionExtensionBase {
ov::frontend::CreatorFunctionNamed m_converter;
};

template <typename Type, typename Value,
typename std::enable_if<std::is_same<Type, ConversionExtension>::value, bool>::type = true>
auto as_type_ptr(const Value& value) -> decltype(::ov::util::AsTypePtr<Value>::template call<Type>(value)) {
return ::ov::util::AsTypePtr<Value>::template call<Type>(value);
}

} // namespace paddle
} // namespace frontend
} // namespace ov
7 changes: 0 additions & 7 deletions src/frontends/paddle/src/extensions.cpp

This file was deleted.

2 changes: 1 addition & 1 deletion src/frontends/paddle/src/frontend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,7 @@ void FrontEnd::add_extension(const std::shared_ptr<ov::Extension>& extension) {
m_op_translators[common_conv_ext->get_op_type()] = [=](const NodeContext& context) {
return common_conv_ext->get_converter_named()(context);
};
} else if (const auto& paddle_conv_ext = std::dynamic_pointer_cast<ConversionExtension>(extension)) {
} else if (const auto& paddle_conv_ext = as_type_ptr<ConversionExtension>(extension)) {
m_conversion_extensions.push_back(paddle_conv_ext);
m_op_translators[paddle_conv_ext->get_op_type()] = [=](const NodeContext& context) {
return paddle_conv_ext->get_converter()(context);
Expand Down
2 changes: 1 addition & 1 deletion src/frontends/paddle/tests/conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class PaddleFrontendWrapper : public ov::frontend::paddle::FrontEnd {
void add_extension(const std::shared_ptr<ov::Extension>& extension) override {
ov::frontend::paddle::FrontEnd::add_extension(extension);

if (auto conv_ext = std::dynamic_pointer_cast<ConversionExtension>(extension)) {
if (auto conv_ext = ov::as_type_ptr<ConversionExtension>(extension)) {
EXPECT_NE(std::find(m_conversion_extensions.begin(), m_conversion_extensions.end(), conv_ext),
m_conversion_extensions.end())
<< "ConversionExtension is not registered.";
Expand Down
Loading

0 comments on commit 15fc254

Please sign in to comment.