Skip to content

Commit

Permalink
Use wrapper function for pfnGetNativeBinary2
Browse files Browse the repository at this point in the history
  • Loading branch information
MirceaDan99 committed Sep 24, 2024
1 parent ab66b2b commit 8640ccc
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 9 deletions.
6 changes: 5 additions & 1 deletion src/plugins/intel_npu/src/backend/include/zero_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,11 @@ struct ze_graph_dditable_ext_decorator final {
}

// version 1.7
ze_pfnGraphGetNativeBinary_ext_2_t pfnGetNativeBinary2;
ze_result_t ZE_APICALL pfnGetNativeBinary2(ze_graph_handle_t hGraph,
size_t* pSize, uint8_t** pGraphNativeBinary) {
throwWhenUnsupported("pfnGetNativeBinary2", ZE_GRAPH_EXT_VERSION_1_7);
return _impl->pfnGetNativeBinary2(hDevice, pSize, pGraphNativeBinary);
}
};

using ze_graph_dditable_ext_curr_t = ze_graph_dditable_ext_decorator;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,12 @@ class LevelZeroCompilerInDriver final : public ICompiler {
std::vector<IODescriptor>& outputs) const;

template <typename T = TableExtension, typename std::enable_if_t<UseCopyForNativeBinary(T), bool> = true>
void getNativeBinary(ze_graph_dditable_ext_curr_t& graphDdiTableExt,
void getNativeBinary(TableExtension* graphDdiTableExt,
ze_graph_handle_t graphHandle, std::vector<uint8_t>& blob,
uint8_t** blobPtr, size_t* blobSize) const;

template <typename T = TableExtension, typename std::enable_if_t<!UseCopyForNativeBinary(T), bool> = true>
void getNativeBinary(ze_graph_dditable_ext_curr_t& graphDdiTableExt,
void getNativeBinary(TableExtension* graphDdiTableExt,
ze_graph_handle_t graphHandle, std::vector<uint8_t>& /* unusedBlob */,
uint8_t** blobPtr, size_t* blobSize) const;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -364,12 +364,12 @@ void LevelZeroCompilerInDriver<TableExtension>::release(std::shared_ptr<const Ne

template <typename TableExtension>
template <typename T, std::enable_if_t<UseCopyForNativeBinary(T), bool>>
void LevelZeroCompilerInDriver<TableExtension>::getNativeBinary(ze_graph_dditable_ext_curr_t& graphDdiTableExt,
void LevelZeroCompilerInDriver<TableExtension>::getNativeBinary(TableExtension* graphDdiTableExt,
ze_graph_handle_t graphHandle,
std::vector<uint8_t>& blob,
uint8_t** blobPtr, size_t* blobSize) const {
// Get blob size first
auto result = _graphDdiTableExt.pfnGetNativeBinary(graphHandle, blobSize, nullptr);
auto result = _graphDdiTableExt->pfnGetNativeBinary(graphHandle, blobSize, nullptr);
blob.resize(*blobSize);

OPENVINO_ASSERT(result == ZE_RESULT_SUCCESS,
Expand All @@ -383,7 +383,7 @@ void LevelZeroCompilerInDriver<TableExtension>::getNativeBinary(ze_graph_dditabl
getLatestBuildError());

// Get blob data
result = _graphDdiTableExt.pfnGetNativeBinary(graphHandle, blobSize, blob.data());
result = _graphDdiTableExt->pfnGetNativeBinary(graphHandle, blobSize, blob.data());

OPENVINO_ASSERT(result == ZE_RESULT_SUCCESS,
"Failed to compile network. L0 pfnGetNativeBinary get blob data",
Expand All @@ -400,12 +400,12 @@ void LevelZeroCompilerInDriver<TableExtension>::getNativeBinary(ze_graph_dditabl

template <typename TableExtension>
template <typename T, std::enable_if_t<!UseCopyForNativeBinary(T), bool>>
void LevelZeroCompilerInDriver<TableExtension>::getNativeBinary(ze_graph_dditable_ext_curr_t& graphDdiTableExt,
void LevelZeroCompilerInDriver<TableExtension>::getNativeBinary(TableExtension* graphDdiTableExt,
ze_graph_handle_t graphHandle,
std::vector<uint8_t>& /* unusedBlob */,
uint8_t** blobPtr, size_t* blobSize) const {
// Get blob ptr and size
auto result = _graphDdiTableExt.pfnGetNativeBinary2(graphHandle, blobSize, blobPtr);
auto result = _graphDdiTableExt->pfnGetNativeBinary2(graphHandle, blobSize, blobPtr);

OPENVINO_ASSERT(result == ZE_RESULT_SUCCESS,
"Failed to compile network. L0 pfnGetNativeBinary get blob size",
Expand All @@ -425,7 +425,7 @@ CompiledNetwork LevelZeroCompilerInDriver<TableExtension>::getCompiledNetwork(
_logger.info("LevelZeroCompilerInDriver getCompiledNetwork get blob from graphHandle");
ze_graph_handle_t graphHandle = static_cast<ze_graph_handle_t>(networkDescription->metadata.graphHandle);

uint8_t* blobPtr;
uint8_t* blobPtr = nullptr;
size_t blobSize = -1;
std::vector<uint8_t> blob;

Expand Down

0 comments on commit 8640ccc

Please sign in to comment.