Skip to content

Commit

Permalink
Ignore allocator type in ExecutionProviders allocator map. Make defau…
Browse files Browse the repository at this point in the history
…lt initialization of OrtMemoryInfo more clearly invalid. (#2768)

* Remove allocator type from the key comparison in ExecutionProviders.
Remove usage of DummyArena as it's no longer necessary.

* Fix x86 tests where arena allocator is disabled.
Make initialization of OrtMemoryInfo clearer by adding Invalid enum value.

* Make OrtValueNameIdxMap::MaxIdx more intuitive.
  • Loading branch information
skottmckay authored Jan 14, 2020
1 parent b308e82 commit 98cb41a
Show file tree
Hide file tree
Showing 12 changed files with 155 additions and 121 deletions.
40 changes: 24 additions & 16 deletions include/onnxruntime/core/framework/allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,26 +85,27 @@ struct OrtMemoryInfo {

// use string for name, so we could have customized allocator in execution provider.
const char* name;
int id;
OrtMemType mem_type;
OrtAllocatorType type;
int id = -1;
OrtMemType mem_type = OrtMemTypeDefault;
OrtAllocatorType alloc_type = Invalid;
OrtDevice device;

constexpr OrtMemoryInfo(const char* name_, OrtAllocatorType type_, OrtDevice device_ = OrtDevice(), int id_ = 0, OrtMemType mem_type_ = OrtMemTypeDefault)
constexpr OrtMemoryInfo(const char* name_, OrtAllocatorType type_, OrtDevice device_ = OrtDevice(), int id_ = 0,
OrtMemType mem_type_ = OrtMemTypeDefault)
#if (defined(__GNUC__) || defined(__clang__))
__attribute__((nonnull))
#endif
: name(name_),
id(id_),
mem_type(mem_type_),
type(type_),
alloc_type(type_),
device(device_) {
}

// To make OrtMemoryInfo become a valid key in std map
inline bool operator<(const OrtMemoryInfo& other) const {
if (type != other.type)
return type < other.type;
bool operator<(const OrtMemoryInfo& other) const {
if (alloc_type != other.alloc_type)
return alloc_type < other.alloc_type;
if (mem_type != other.mem_type)
return mem_type < other.mem_type;
if (id != other.id)
Expand All @@ -113,20 +114,22 @@ struct OrtMemoryInfo {
return strcmp(name, other.name) < 0;
}

inline std::string ToString() const {
std::string ToString() const {
std::ostringstream ostr;
ostr << "OrtMemoryInfo: ["
<< " name:" << name
<< " id:" << id
<< " mem_type:" << mem_type
<< " type:" << type
<< " alloc_type:" << alloc_type
<< "]";
return ostr.str();
}
};

inline bool operator==(const OrtMemoryInfo& left, const OrtMemoryInfo& other) {
return left.mem_type == other.mem_type && left.type == other.type && left.id == other.id &&
return left.mem_type == other.mem_type &&
left.alloc_type == other.alloc_type &&
left.id == other.id &&
strcmp(left.name, other.name) == 0;
}

Expand Down Expand Up @@ -213,9 +216,11 @@ class IAllocator {
if (!std::is_void<T>::value) {
// sizeof(void) isn't valid, but the compiler isn't smart enough to ignore that this line isn't
// reachable if T is void. use std::conditional to 'use' void* in the sizeof call
if (!CalcMemSizeForArray(count_or_bytes, sizeof(typename std::conditional<std::is_void<T>::value, void*, T>::type),
if (!CalcMemSizeForArray(count_or_bytes,
sizeof(typename std::conditional<std::is_void<T>::value, void*, T>::type),
&alloc_size)) return nullptr;
}

return IAllocatorUniquePtr<T>{
static_cast<T*>(allocator->Alloc(alloc_size)), // allocate
[=](T* ptr) { allocator->Free(ptr); }}; // capture IAllocator so it's always valid, and use as deleter
Expand All @@ -224,22 +229,26 @@ class IAllocator {

// Computes the byte size of an array of `nmemb` elements of `size` bytes each and stores it in *out,
// rounded with the alignment mask when `alignment` is non-zero.
// Returns false, without writing *out, if nmemb * size would exceed max_size (size_t max minus alignment);
// returns true otherwise.
template <size_t alignment>
bool IAllocator::CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t* out) noexcept {
// 2^(half the bit width of size_t) minus alignment. When both operands are below this bound their
// product cannot overflow size_t, so the division-based checks below are skipped.
// NOTE(review): the next two lines are the before/after forms of the same declaration as rendered by the
// diff view; they evaluate to the same value.
static constexpr size_t max_allowed = (static_cast<size_t>(1) << (static_cast<size_t>(std::numeric_limits<size_t>::digits >> 1))) - alignment;
static constexpr size_t max_allowed = (size_t(1) << (size_t(std::numeric_limits<size_t>::digits >> 1))) - alignment;
// Largest product that still leaves room for the alignment padding added below.
static constexpr size_t max_size = std::numeric_limits<size_t>::max() - alignment;
static constexpr size_t alignment_mask = alignment - 1;

// Strictly, only `max_size / nmemb < size` needs to be checked;
// the max_allowed comparison exists to avoid an unnecessary DIV in the common case.
if (nmemb >= max_allowed && max_size / nmemb < size) {
return false;
}

// Same check for a large element size; `nmemb > 0` guards the division against a zero divisor.
if (size >= max_allowed &&
nmemb > 0 && max_size / nmemb < size) {
return false;
}

// alignment == 0 means no rounding is applied.
if (alignment == 0)
*out = size * nmemb;
else
// Add-mask-then-clear rounding. NOTE(review): this rounds up to a multiple of `alignment` only when
// alignment is a power of two - confirm that all instantiations satisfy this.
*out = (size * nmemb + alignment_mask) & ~static_cast<size_t>(alignment_mask);

return true;
}

Expand Down Expand Up @@ -297,11 +306,10 @@ class MiMallocAllocator : public IDeviceAllocator {

#endif


#ifdef USE_MIMALLOC
using TAllocator = MiMallocAllocator;
using TAllocator = MiMallocAllocator;
#else
using TAllocator = CPUAllocator;
using TAllocator = CPUAllocator;
#endif

using AllocatorPtr = std::shared_ptr<IAllocator>;
Expand Down
151 changes: 76 additions & 75 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,12 @@ typedef enum ONNXTensorElementDataType {
ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, // maps to c++ type std::string
ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL,
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16,
ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, // maps to c type double
ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, // maps to c type uint32_t
ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, // maps to c type uint64_t
ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64, // complex with float32 real and imaginary components
ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128, // complex with float64 real and imaginary components
ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 // Non-IEEE floating-point format based on IEEE754 single-precision
ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, // maps to c type double
ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, // maps to c type uint32_t
ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, // maps to c type uint64_t
ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64, // complex with float32 real and imaginary components
ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128, // complex with float64 real and imaginary components
ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 // Non-IEEE floating-point format based on IEEE754 single-precision
} ONNXTensorElementDataType;

// Synced with onnx TypeProto oneof
Expand Down Expand Up @@ -193,6 +193,7 @@ struct OrtCustomOp;
typedef struct OrtCustomOp OrtCustomOp;

// Allocator type tag stored in OrtMemoryInfo::alloc_type and returned by the MemoryInfoGetType API.
typedef enum OrtAllocatorType {
Invalid = -1,  // value OrtMemoryInfo::alloc_type holds by default, i.e. no allocator type has been set
OrtDeviceAllocator = 0,  // NOTE(review): presumably a plain (non-arena) device allocator - confirm
OrtArenaAllocator = 1  // NOTE(review): presumably an arena-based allocator - confirm
} OrtAllocatorType;
Expand Down Expand Up @@ -234,14 +235,14 @@ struct OrtApi {
const char*(ORT_API_CALL* GetErrorMessage)(_In_ const OrtStatus* status)NO_EXCEPTION ORT_ALL_ARGS_NONNULL;

/**
* \param out Should be freed by `OrtReleaseEnv` after use
*/
* \param out Should be freed by `OrtReleaseEnv` after use
*/
OrtStatus*(ORT_API_CALL* CreateEnv)(OrtLoggingLevel default_logging_level, _In_ const char* logid, _Outptr_ OrtEnv** out)
NO_EXCEPTION ORT_ALL_ARGS_NONNULL;

/**
* \param out Should be freed by `OrtReleaseEnv` after use
*/
* \param out Should be freed by `OrtReleaseEnv` after use
*/
OrtStatus*(ORT_API_CALL* CreateEnvWithCustomLogger)(OrtLoggingFunction logging_function,
_In_opt_ void* logger_param, OrtLoggingLevel default_warning_level,
_In_ const char* logid,
Expand All @@ -268,8 +269,8 @@ struct OrtApi {
_In_ const char* const* output_names, size_t output_names_len, _Outptr_ OrtValue** output)NO_EXCEPTION;

/**
* \return A pointer of the newly created object. The pointer should be freed by OrtReleaseSessionOptions after use
*/
* \return A pointer of the newly created object. The pointer should be freed by OrtReleaseSessionOptions after use
*/
OrtStatus*(ORT_API_CALL* CreateSessionOptions)(_Outptr_ OrtSessionOptions** options)NO_EXCEPTION;

// Set filepath to save optimized model after graph level transformations.
Expand Down Expand Up @@ -325,36 +326,36 @@ struct OrtApi {
OrtStatus*(ORT_API_CALL* CreateCustomOpDomain)(_In_ const char* domain, _Outptr_ OrtCustomOpDomain** out)NO_EXCEPTION;

/*
* Add custom ops to the OrtCustomOpDomain
* Note: The OrtCustomOp* pointer must remain valid until the OrtCustomOpDomain using it is released
*/
* Add custom ops to the OrtCustomOpDomain
* Note: The OrtCustomOp* pointer must remain valid until the OrtCustomOpDomain using it is released
*/
OrtStatus*(ORT_API_CALL* CustomOpDomain_Add)(_Inout_ OrtCustomOpDomain* custom_op_domain, _In_ OrtCustomOp* op)NO_EXCEPTION;

/*
* Add a custom op domain to the OrtSessionOptions
* Note: The OrtCustomOpDomain* must not be deleted until the sessions using it are released
*/
* Add a custom op domain to the OrtSessionOptions
* Note: The OrtCustomOpDomain* must not be deleted until the sessions using it are released
*/
OrtStatus*(ORT_API_CALL* AddCustomOpDomain)(_Inout_ OrtSessionOptions* options, _In_ OrtCustomOpDomain* custom_op_domain)NO_EXCEPTION;

/*
* Loads a DLL named 'library_path' and looks for this entry point:
* OrtStatus* RegisterCustomOps(OrtSessionOptions * options, const OrtApiBase* api);
* It then passes in the provided session options to this function along with the api base.
* The handle to the loaded library is returned in library_handle. It can be freed by the caller after all sessions using the passed in
* session options are destroyed, or if an error occurs and it is non null.
* Loads a DLL named 'library_path' and looks for this entry point:
* OrtStatus* RegisterCustomOps(OrtSessionOptions * options, const OrtApiBase* api);
* It then passes in the provided session options to this function along with the api base.
* The handle to the loaded library is returned in library_handle. It can be freed by the caller after all sessions using the passed in
* session options are destroyed, or if an error occurs and it is non null.
*/
OrtStatus*(ORT_API_CALL* RegisterCustomOpsLibrary)(_Inout_ OrtSessionOptions* options, _In_ const char* library_path, void** library_handle)NO_EXCEPTION;

/**
* To use additional providers, you must build ORT with the extra providers enabled. Then call one of these
* functions to enable them in the session:
* OrtSessionOptionsAppendExecutionProvider_CPU
* OrtSessionOptionsAppendExecutionProvider_CUDA
* OrtSessionOptionsAppendExecutionProvider_<remaining providers...>
* The order they are called indicates the preference order as well. In other words call this method
* on your most preferred execution provider first followed by the less preferred ones.
* If none are called Ort will use its internal CPU execution provider.
*/
* To use additional providers, you must build ORT with the extra providers enabled. Then call one of these
* functions to enable them in the session:
* OrtSessionOptionsAppendExecutionProvider_CPU
* OrtSessionOptionsAppendExecutionProvider_CUDA
* OrtSessionOptionsAppendExecutionProvider_<remaining providers...>
* The order they are called indicates the preference order as well. In other words call this method
* on your most preferred execution provider first followed by the less preferred ones.
* If none are called Ort will use its internal CPU execution provider.
*/

OrtStatus*(ORT_API_CALL* SessionGetInputCount)(_In_ const OrtSession* sess, _Out_ size_t* out)NO_EXCEPTION;
OrtStatus*(ORT_API_CALL* SessionGetOutputCount)(_In_ const OrtSession* sess, _Out_ size_t* out)NO_EXCEPTION;
Expand Down Expand Up @@ -432,39 +433,39 @@ struct OrtApi {
OrtStatus*(ORT_API_CALL* GetTensorMutableData)(_Inout_ OrtValue* value, _Outptr_ void** out)NO_EXCEPTION;

/**
* \param value A tensor created from CreateTensor... function.
* \param s each A string array. Each string in this array must be null terminated.
* \param s_len length of s
*/
* \param value A tensor created from OrtCreateTensor... function.
* \param s each A string array. Each string in this array must be null terminated.
* \param s_len length of s
*/
OrtStatus*(ORT_API_CALL* FillStringTensor)(_Inout_ OrtValue* value, _In_ const char* const* s, size_t s_len)NO_EXCEPTION;

/**
* \param value A tensor created from CreateTensor... function.
* \param len total data length, not including the trailing '\0' chars.
*/
* \param value A tensor created from OrtCreateTensor... function.
* \param len total data length, not including the trailing '\0' chars.
*/
OrtStatus*(ORT_API_CALL* GetStringTensorDataLength)(_In_ const OrtValue* value, _Out_ size_t* len)NO_EXCEPTION;

/**
* \param s string contents. Each string is NOT null-terminated.
* \param value A tensor created from CreateTensor... function.
* \param s_len total data length, get it from GetStringTensorDataLength
*/
* \param s string contents. Each string is NOT null-terminated.
* \param value A tensor created from OrtCreateTensor... function.
* \param s_len total data length, get it from OrtGetStringTensorDataLength
*/
OrtStatus*(ORT_API_CALL* GetStringTensorContent)(_In_ const OrtValue* value, _Out_ void* s, size_t s_len,
_Out_ size_t* offsets, size_t offsets_len)NO_EXCEPTION;

/**
* Don't free the 'out' value
*/
* Don't free the 'out' value
*/
OrtStatus*(ORT_API_CALL* CastTypeInfoToTensorInfo)(_In_ const OrtTypeInfo*, _Out_ const OrtTensorTypeAndShapeInfo** out)NO_EXCEPTION;

/**
* Return OnnxType from OrtTypeInfo
*/
* Return OnnxType from OrtTypeInfo
*/
OrtStatus*(ORT_API_CALL* GetOnnxTypeFromTypeInfo)(_In_ const OrtTypeInfo*, _Out_ enum ONNXType* out)NO_EXCEPTION;

/**
* The 'out' value should be released by calling OrtReleaseTensorTypeAndShapeInfo
*/
* The 'out' value should be released by calling OrtReleaseTensorTypeAndShapeInfo
*/
OrtStatus*(ORT_API_CALL* CreateTensorTypeAndShapeInfo)(_Outptr_ OrtTensorTypeAndShapeInfo** out)NO_EXCEPTION;

OrtStatus*(ORT_API_CALL* SetTensorElementType)(_Inout_ OrtTensorTypeAndShapeInfo*, enum ONNXTensorElementDataType type)NO_EXCEPTION;
Expand Down Expand Up @@ -592,36 +593,36 @@ struct OrtApi {
_Outptr_ OrtValue** out)NO_EXCEPTION;

/**
* Construct OrtValue that contains a value of non-standard type created for
* experiments or while awaiting standardization. OrtValue in this case would contain
* an internal representation of the Opaque type. Opaque types are distinguished between
* each other by two strings 1) domain and 2) type name. The combination of the two
* must be unique, so the type representation is properly identified internally. The combination
* must be properly registered from within ORT at both compile/run time or by another API.
*
* To construct the OrtValue pass domain and type names, also a pointer to a data container
* the type of which must be known to both ORT and the client program. That data container may or may
* not match the internal representation of the Opaque type. The sizeof(data_container) is passed for
* verification purposes.
*
* \domain_name - domain name for the Opaque type, null terminated.
* \type_name - type name for the Opaque type, null terminated.
* \data_container - data to populate OrtValue
* \data_container_size - sizeof() of the data container. Must match the sizeof() of the expected
* data_container size internally.
*/
* Construct OrtValue that contains a value of non-standard type created for
* experiments or while awaiting standardization. OrtValue in this case would contain
* an internal representation of the Opaque type. Opaque types are distinguished between
* each other by two strings 1) domain and 2) type name. The combination of the two
* must be unique, so the type representation is properly identified internally. The combination
* must be properly registered from within ORT at both compile/run time or by another API.
*
* To construct the OrtValue pass domain and type names, also a pointer to a data container
* the type of which must be known to both ORT and the client program. That data container may or may
* not match the internal representation of the Opaque type. The sizeof(data_container) is passed for
* verification purposes.
*
* \domain_name - domain name for the Opaque type, null terminated.
* \type_name - type name for the Opaque type, null terminated.
* \data_container - data to populate OrtValue
* \data_container_size - sizeof() of the data container. Must match the sizeof() of the expected
* data_container size internally.
*/
OrtStatus*(ORT_API_CALL* CreateOpaqueValue)(_In_ const char* domain_name, _In_ const char* type_name,
_In_ const void* data_container, size_t data_container_size, _Outptr_ OrtValue** out)NO_EXCEPTION;

/**
* Fetch data from an OrtValue that contains a value of non-standard type created for
* experiments or while awaiting standardization.
* \domain_name - domain name for the Opaque type, null terminated.
* \type_name - type name for the Opaque type, null terminated.
* \data_container - data to populate OrtValue
* \data_container_size - sizeof() of the data container. Must match the sizeof() of the expected
* data_container size internally.
*/
* Fetch data from an OrtValue that contains a value of non-standard type created for
* experiments or while awaiting standardization.
* \domain_name - domain name for the Opaque type, null terminated.
* \type_name - type name for the Opaque type, null terminated.
* \data_container - data to populate OrtValue
* \data_container_size - sizeof() of the data container. Must match the sizeof() of the expected
* data_container size internally.
*/

OrtStatus*(ORT_API_CALL* GetOpaqueValue)(_In_ const char* domain_name, _In_ const char* type_name,
_In_ const OrtValue* in, _Out_ void* data_container, size_t data_container_size)NO_EXCEPTION;
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/framework/allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ ORT_API_STATUS_IMPL(OrtApis::MemoryInfoGetMemType, _In_ const OrtMemoryInfo* ptr
}

ORT_API_STATUS_IMPL(OrtApis::MemoryInfoGetType, _In_ const OrtMemoryInfo* ptr, _Out_ OrtAllocatorType* out) {
*out = ptr->type;
*out = ptr->alloc_type;
return nullptr;
}

Expand Down
19 changes: 18 additions & 1 deletion onnxruntime/core/framework/execution_providers.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,25 @@ class ExecutionProviders {

// maps for fast lookup of an index into exec_providers_
std::unordered_map<std::string, size_t> provider_idx_map_;

// currently the allocator type is an implementation detail and we don't make any behavioral choices based on it,
// so exclude it from the key comparison for allocator_idx_map_.
// we also don't expect to have two allocators with the same name, one using an arena and one not.
// Less-than functor for OrtMemoryInfo keys that leaves alloc_type out of the comparison.
// Keys are ranked by mem_type first, then id, then name (compared with strcmp), so two
// infos that differ only in alloc_type compare as equivalent.
struct OrtMemoryInfoLessThanIgnoreAllocType {
  bool operator()(const OrtMemoryInfo& lhs, const OrtMemoryInfo& rhs) const {
    // alloc_type is deliberately not consulted here.
    if (lhs.mem_type == rhs.mem_type) {
      if (lhs.id == rhs.id) {
        // last tie-breaker: byte-wise order of the two names
        return strcmp(lhs.name, rhs.name) < 0;
      }
      return lhs.id < rhs.id;
    }
    return lhs.mem_type < rhs.mem_type;
  }
};

// using std::map as OrtMemoryInfo would need a custom hash function to be used with unordered_map,
// and as this isn't performance critical it's not worth the maintenance overhead of adding one.
std::map<OrtMemoryInfo, size_t> allocator_idx_map_;
std::map<OrtMemoryInfo, size_t, OrtMemoryInfoLessThanIgnoreAllocType> allocator_idx_map_;
};
} // namespace onnxruntime
Loading

0 comments on commit 98cb41a

Please sign in to comment.