Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support runtime defined function wrapping of library module packed functions #9342

Merged
merged 6 commits into from
Oct 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 85 additions & 45 deletions src/runtime/dso_library.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,62 +37,102 @@
namespace tvm {
namespace runtime {

// Dynamic shared libary.
// This is the default module TVM used for host-side AOT
/*!
* \brief Dynamic shared library object used to load
* and retrieve symbols by name. This is the default
* module TVM uses for host-side AOT compilation.
*/
class DSOLibrary final : public Library {
public:
~DSOLibrary() {
if (lib_handle_) Unload();
}
void Init(const std::string& name) { Load(name); }

void* GetSymbol(const char* name) final { return GetSymbol_(name); }
~DSOLibrary();
/*!
* \brief Initialize by loading and storing
* a handle to the underlying shared library.
* \param name The string name/path to the
* shared library over which to initialize.
*/
void Init(const std::string& name);
/*!
* \brief Returns the symbol address within
* the shared library for a given symbol name.
* \param name The name of the symbol.
* \return The symbol.
*/
void* GetSymbol(const char* name) final;

private:
// Platform dependent handling.
/*! \brief Private implementation of symbol lookup.
* Implementation is operating system dependent.
* \param The name of the symbol.
* \return The symbol.
*/
void* GetSymbol_(const char* name);
/*! \brief Implementation of shared library load.
* Implementation is operating system dependent.
* \param The name/path of the shared library.
*/
void Load(const std::string& name);
/*! \brief Implementation of shared library unload.
* Implementation is operating system dependent.
*/
void Unload();

#if defined(_WIN32)
// library handle
//! \brief Windows library handle
HMODULE lib_handle_{nullptr};

void* GetSymbol_(const char* name) {
return reinterpret_cast<void*>(GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*)
}

// Load the library
void Load(const std::string& name) {
// use wstring version that is needed by LLVM.
std::wstring wname(name.begin(), name.end());
lib_handle_ = LoadLibraryW(wname.c_str());
ICHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name;
}

void Unload() {
FreeLibrary(lib_handle_);
lib_handle_ = nullptr;
}
#else
// Library handle
// \brief Linux library handle
void* lib_handle_{nullptr};
// load the library
void Load(const std::string& name) {
lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL);
ICHECK(lib_handle_ != nullptr)
<< "Failed to load dynamic shared library " << name << " " << dlerror();
}

void* GetSymbol_(const char* name) { return dlsym(lib_handle_, name); }

void Unload() {
dlclose(lib_handle_);
lib_handle_ = nullptr;
}
#endif
};

TVM_REGISTER_GLOBAL("runtime.module.loadfile_so").set_body([](TVMArgs args, TVMRetValue* rv) {
DSOLibrary::~DSOLibrary() {
if (lib_handle_) Unload();
}

void DSOLibrary::Init(const std::string& name) { Load(name); }

void* DSOLibrary::GetSymbol(const char* name) { return GetSymbol_(name); }

#if defined(_WIN32)

void* DSOLibrary::GetSymbol_(const char* name) {
return reinterpret_cast<void*>(GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*)
}

void DSOLibrary::Load(const std::string& name) {
// use wstring version that is needed by LLVM.
std::wstring wname(name.begin(), name.end());
lib_handle_ = LoadLibraryW(wname.c_str());
ICHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name;
}

void DSOLibrary::Unload() {
FreeLibrary(lib_handle_);
lib_handle_ = nullptr;
}

#else

void DSOLibrary::Load(const std::string& name) {
lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL);
ICHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name << " "
<< dlerror();
}

void* DSOLibrary::GetSymbol_(const char* name) { return dlsym(lib_handle_, name); }

void DSOLibrary::Unload() {
dlclose(lib_handle_);
lib_handle_ = nullptr;
}

#endif

ObjectPtr<Library> CreateDSOLibraryObject(std::string library_path) {
auto n = make_object<DSOLibrary>();
n->Init(args[0]);
*rv = CreateModuleFromLibrary(n);
});
n->Init(library_path);
return n;
}
} // namespace runtime
} // namespace tvm
24 changes: 16 additions & 8 deletions src/runtime/library_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ namespace runtime {
// Library module that exposes symbols from a library.
class LibraryModuleNode final : public ModuleNode {
public:
explicit LibraryModuleNode(ObjectPtr<Library> lib) : lib_(lib) {}
explicit LibraryModuleNode(ObjectPtr<Library> lib, PackedFuncWrapper wrapper)
: lib_(lib), packed_func_wrapper_(wrapper) {}

const char* type_key() const final { return "library"; }

Expand All @@ -53,11 +54,12 @@ class LibraryModuleNode final : public ModuleNode {
faddr = reinterpret_cast<TVMBackendPackedCFunc>(lib_->GetSymbol(name.c_str()));
}
if (faddr == nullptr) return PackedFunc();
return WrapPackedFunc(faddr, sptr_to_self);
return packed_func_wrapper_(faddr, sptr_to_self);
}

private:
ObjectPtr<Library> lib_;
PackedFuncWrapper packed_func_wrapper_;
};

/*!
Expand Down Expand Up @@ -128,7 +130,8 @@ Module LoadModuleFromBinary(const std::string& type_key, dmlc::Stream* stream) {
* \param root_module the output root module
* \param dso_ctx_addr the output dso module
*/
void ProcessModuleBlob(const char* mblob, ObjectPtr<Library> lib, runtime::Module* root_module,
void ProcessModuleBlob(const char* mblob, ObjectPtr<Library> lib,
PackedFuncWrapper packed_func_wrapper, runtime::Module* root_module,
runtime::ModuleNode** dso_ctx_addr = nullptr) {
ICHECK(mblob != nullptr);
uint64_t nbytes = 0;
Expand All @@ -152,7 +155,7 @@ void ProcessModuleBlob(const char* mblob, ObjectPtr<Library> lib, runtime::Modul
// "_lib" serves as a placeholder in the module import tree to indicate where
// to place the DSOModule
if (tkey == "_lib") {
auto dso_module = Module(make_object<LibraryModuleNode>(lib));
auto dso_module = Module(make_object<LibraryModuleNode>(lib, packed_func_wrapper));
*dso_ctx_addr = dso_module.operator->();
++num_dso_module;
modules.emplace_back(dso_module);
Expand All @@ -170,7 +173,7 @@ void ProcessModuleBlob(const char* mblob, ObjectPtr<Library> lib, runtime::Modul
// if we are using old dll, we don't have import tree
// so that we can't reconstruct module relationship using import tree
if (import_tree_row_ptr.empty()) {
auto n = make_object<LibraryModuleNode>(lib);
auto n = make_object<LibraryModuleNode>(lib, packed_func_wrapper);
auto module_import_addr = ModuleInternal::GetImportsAddr(n.operator->());
for (const auto& m : modules) {
module_import_addr->emplace_back(m);
Expand All @@ -194,17 +197,17 @@ void ProcessModuleBlob(const char* mblob, ObjectPtr<Library> lib, runtime::Modul
}
}

Module CreateModuleFromLibrary(ObjectPtr<Library> lib) {
Module CreateModuleFromLibrary(ObjectPtr<Library> lib, PackedFuncWrapper packed_func_wrapper) {
InitContextFunctions([lib](const char* fname) { return lib->GetSymbol(fname); });
auto n = make_object<LibraryModuleNode>(lib);
auto n = make_object<LibraryModuleNode>(lib, packed_func_wrapper);
// Load the imported modules
const char* dev_mblob =
reinterpret_cast<const char*>(lib->GetSymbol(runtime::symbol::tvm_dev_mblob));

Module root_mod;
runtime::ModuleNode* dso_ctx_addr = nullptr;
if (dev_mblob != nullptr) {
ProcessModuleBlob(dev_mblob, lib, &root_mod, &dso_ctx_addr);
ProcessModuleBlob(dev_mblob, lib, packed_func_wrapper, &root_mod, &dso_ctx_addr);
} else {
// Only have one single DSO Module
root_mod = Module(n);
Expand All @@ -218,5 +221,10 @@ Module CreateModuleFromLibrary(ObjectPtr<Library> lib) {

return root_mod;
}

TVM_REGISTER_GLOBAL("runtime.module.loadfile_so").set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectPtr<Library> n = CreateDSOLibraryObject(args[0]);
*rv = CreateModuleFromLibrary(n);
});
} // namespace runtime
} // namespace tvm
21 changes: 20 additions & 1 deletion src/runtime/library_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,35 @@ PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr<Object>&
*/
void InitContextFunctions(std::function<void*(const char*)> fgetsymbol);

/*!
* \brief Type alias for funcion to wrap a TVMBackendPackedCFunc.
* \param The function address imported from a module.
* \param mptr The module pointer node.
* \return Packed function that wraps the invocation of the function at faddr.
*/
using PackedFuncWrapper =
std::function<PackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr<Object>& mptr)>;

/*! \brief Return a library object interface over dynamic shared
* libraries in Windows and Linux providing support for
* loading/unloading and symbol lookup.
* \param Full path to shared library.
* \return Returns pointer to the Library providing symbol lookup.
*/
ObjectPtr<Library> CreateDSOLibraryObject(std::string library_path);

/*!
* \brief Create a module from a library.
*
* \param lib The library.
* \param wrapper Optional function used to wrap a TVMBackendPackedCFunc,
* by default WrapPackedFunc is used.
* \return The corresponding loaded module.
*
* \note This function can create multiple linked modules
* by parsing the binary blob section of the library.
*/
Module CreateModuleFromLibrary(ObjectPtr<Library> lib);
Module CreateModuleFromLibrary(ObjectPtr<Library> lib, PackedFuncWrapper wrapper = WrapPackedFunc);
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_LIBRARY_MODULE_H_