Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ignore allocator type in ExecutionProviders allocator map. Make default initialization of OrtMemoryInfo more clearly invalid. #2768

Merged
merged 7 commits into from
Jan 14, 2020
40 changes: 24 additions & 16 deletions include/onnxruntime/core/framework/allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,26 +85,27 @@ struct OrtMemoryInfo {

// use string for name, so we could have customized allocator in execution provider.
const char* name;
int id;
OrtMemType mem_type;
OrtAllocatorType type;
int id = -1;
OrtMemType mem_type = OrtMemTypeDefault;
OrtAllocatorType alloc_type = Invalid;
OrtDevice device;

constexpr OrtMemoryInfo(const char* name_, OrtAllocatorType type_, OrtDevice device_ = OrtDevice(), int id_ = 0, OrtMemType mem_type_ = OrtMemTypeDefault)
constexpr OrtMemoryInfo(const char* name_, OrtAllocatorType type_, OrtDevice device_ = OrtDevice(), int id_ = 0,
OrtMemType mem_type_ = OrtMemTypeDefault)
#if (defined(__GNUC__) || defined(__clang__))
__attribute__((nonnull))
#endif
: name(name_),
id(id_),
mem_type(mem_type_),
type(type_),
alloc_type(type_),
device(device_) {
}

// To make OrtMemoryInfo become a valid key in std map
inline bool operator<(const OrtMemoryInfo& other) const {
if (type != other.type)
return type < other.type;
bool operator<(const OrtMemoryInfo& other) const {
if (alloc_type != other.alloc_type)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definition is implicitly inline, so 'inline' keyword here is unnecessary.

return alloc_type < other.alloc_type;
if (mem_type != other.mem_type)
return mem_type < other.mem_type;
if (id != other.id)
Expand All @@ -113,20 +114,22 @@ struct OrtMemoryInfo {
return strcmp(name, other.name) < 0;
}

inline std::string ToString() const {
std::string ToString() const {
std::ostringstream ostr;
ostr << "OrtMemoryInfo: ["
<< " name:" << name
<< " id:" << id
<< " mem_type:" << mem_type
<< " type:" << type
<< " alloc_type:" << alloc_type
<< "]";
return ostr.str();
}
};

inline bool operator==(const OrtMemoryInfo& left, const OrtMemoryInfo& other) {
return left.mem_type == other.mem_type && left.type == other.type && left.id == other.id &&
return left.mem_type == other.mem_type &&
left.alloc_type == other.alloc_type &&
left.id == other.id &&
strcmp(left.name, other.name) == 0;
}

Expand Down Expand Up @@ -213,9 +216,11 @@ class IAllocator {
if (!std::is_void<T>::value) {
// sizeof(void) isn't valid, but the compiler isn't smart enough to ignore that this line isn't
// reachable if T is void. use std::conditional to 'use' void* in the sizeof call
if (!CalcMemSizeForArray(count_or_bytes, sizeof(typename std::conditional<std::is_void<T>::value, void*, T>::type),
if (!CalcMemSizeForArray(count_or_bytes,
sizeof(typename std::conditional<std::is_void<T>::value, void*, T>::type),
&alloc_size)) return nullptr;
}

return IAllocatorUniquePtr<T>{
static_cast<T*>(allocator->Alloc(alloc_size)), // allocate
[=](T* ptr) { allocator->Free(ptr); }}; // capture IAllocator so it's always valid, and use as deleter
Expand All @@ -224,22 +229,26 @@ class IAllocator {

template <size_t alignment>
bool IAllocator::CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t* out) noexcept {
static constexpr size_t max_allowed = (static_cast<size_t>(1) << (static_cast<size_t>(std::numeric_limits<size_t>::digits >> 1))) - alignment;
static constexpr size_t max_allowed = (size_t(1) << (size_t(std::numeric_limits<size_t>::digits >> 1))) - alignment;
static constexpr size_t max_size = std::numeric_limits<size_t>::max() - alignment;
static constexpr size_t alignment_mask = alignment - 1;

//Indeed, we only need to check if max_size / nmemb < size
//max_allowed is for avoiding unnecessary DIV.
if (nmemb >= max_allowed && max_size / nmemb < size) {
return false;
}

if (size >= max_allowed &&
nmemb > 0 && max_size / nmemb < size) {
return false;
}

if (alignment == 0)
*out = size * nmemb;
else
*out = (size * nmemb + alignment_mask) & ~static_cast<size_t>(alignment_mask);

return true;
}

Expand Down Expand Up @@ -297,11 +306,10 @@ class MiMallocAllocator : public IDeviceAllocator {

#endif


#ifdef USE_MIMALLOC
using TAllocator = MiMallocAllocator;
using TAllocator = MiMallocAllocator;
#else
using TAllocator = CPUAllocator;
using TAllocator = CPUAllocator;
#endif

using AllocatorPtr = std::shared_ptr<IAllocator>;
Expand Down
151 changes: 76 additions & 75 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,12 @@ typedef enum ONNXTensorElementDataType {
ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, // maps to c++ type std::string
ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL,
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16,
ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, // maps to c type double
ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, // maps to c type uint32_t
ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, // maps to c type uint64_t
ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64, // complex with float32 real and imaginary components
ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128, // complex with float64 real and imaginary components
ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 // Non-IEEE floating-point format based on IEEE754 single-precision
ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, // maps to c type double
ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, // maps to c type uint32_t
ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, // maps to c type uint64_t
ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64, // complex with float32 real and imaginary components
ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128, // complex with float64 real and imaginary components
ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 // Non-IEEE floating-point format based on IEEE754 single-precision
} ONNXTensorElementDataType;

// Synced with onnx TypeProto oneof
Expand Down Expand Up @@ -193,6 +193,7 @@ struct OrtCustomOp;
typedef struct OrtCustomOp OrtCustomOp;

typedef enum OrtAllocatorType {
Invalid = -1,
OrtDeviceAllocator = 0,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the only 'real' change here. The rest is auto formatting by clang including replacing tabs with spaces

OrtArenaAllocator = 1
} OrtAllocatorType;
Expand Down Expand Up @@ -234,14 +235,14 @@ struct OrtApi {
const char*(ORT_API_CALL* GetErrorMessage)(_In_ const OrtStatus* status)NO_EXCEPTION ORT_ALL_ARGS_NONNULL;

/**
* \param out Should be freed by `OrtReleaseEnv` after use
*/
* \param out Should be freed by `OrtReleaseEnv` after use
*/
OrtStatus*(ORT_API_CALL* CreateEnv)(OrtLoggingLevel default_logging_level, _In_ const char* logid, _Outptr_ OrtEnv** out)
NO_EXCEPTION ORT_ALL_ARGS_NONNULL;

/**
* \param out Should be freed by `OrtReleaseEnv` after use
*/
* \param out Should be freed by `OrtReleaseEnv` after use
*/
OrtStatus*(ORT_API_CALL* CreateEnvWithCustomLogger)(OrtLoggingFunction logging_function,
_In_opt_ void* logger_param, OrtLoggingLevel default_warning_level,
_In_ const char* logid,
Expand All @@ -268,8 +269,8 @@ struct OrtApi {
_In_ const char* const* output_names, size_t output_names_len, _Outptr_ OrtValue** output)NO_EXCEPTION;

/**
* \return A pointer of the newly created object. The pointer should be freed by OrtReleaseSessionOptions after use
*/
* \return A pointer of the newly created object. The pointer should be freed by OrtReleaseSessionOptions after use
*/
OrtStatus*(ORT_API_CALL* CreateSessionOptions)(_Outptr_ OrtSessionOptions** options)NO_EXCEPTION;

// Set filepath to save optimized model after graph level transformations.
Expand Down Expand Up @@ -325,36 +326,36 @@ struct OrtApi {
OrtStatus*(ORT_API_CALL* CreateCustomOpDomain)(_In_ const char* domain, _Outptr_ OrtCustomOpDomain** out)NO_EXCEPTION;

/*
* Add custom ops to the OrtCustomOpDomain
* Note: The OrtCustomOp* pointer must remain valid until the OrtCustomOpDomain using it is released
*/
* Add custom ops to the OrtCustomOpDomain
* Note: The OrtCustomOp* pointer must remain valid until the OrtCustomOpDomain using it is released
*/
OrtStatus*(ORT_API_CALL* CustomOpDomain_Add)(_Inout_ OrtCustomOpDomain* custom_op_domain, _In_ OrtCustomOp* op)NO_EXCEPTION;

/*
* Add a custom op domain to the OrtSessionOptions
* Note: The OrtCustomOpDomain* must not be deleted until the sessions using it are released
*/
* Add a custom op domain to the OrtSessionOptions
* Note: The OrtCustomOpDomain* must not be deleted until the sessions using it are released
*/
OrtStatus*(ORT_API_CALL* AddCustomOpDomain)(_Inout_ OrtSessionOptions* options, _In_ OrtCustomOpDomain* custom_op_domain)NO_EXCEPTION;

/*
* Loads a DLL named 'library_path' and looks for this entry point:
* OrtStatus* RegisterCustomOps(OrtSessionOptions * options, const OrtApiBase* api);
* It then passes in the provided session options to this function along with the api base.
* The handle to the loaded library is returned in library_handle. It can be freed by the caller after all sessions using the passed in
* session options are destroyed, or if an error occurs and it is non null.
* Loads a DLL named 'library_path' and looks for this entry point:
* OrtStatus* RegisterCustomOps(OrtSessionOptions * options, const OrtApiBase* api);
* It then passes in the provided session options to this function along with the api base.
* The handle to the loaded library is returned in library_handle. It can be freed by the caller after all sessions using the passed in
* session options are destroyed, or if an error occurs and it is non null.
*/
OrtStatus*(ORT_API_CALL* RegisterCustomOpsLibrary)(_Inout_ OrtSessionOptions* options, _In_ const char* library_path, void** library_handle)NO_EXCEPTION;

/**
* To use additional providers, you must build ORT with the extra providers enabled. Then call one of these
* functions to enable them in the session:
* OrtSessionOptionsAppendExecutionProvider_CPU
* OrtSessionOptionsAppendExecutionProvider_CUDA
* OrtSessionOptionsAppendExecutionProvider_<remaining providers...>
* The order in which they are called indicates the preference order as well. In other words, call this method
* on your most preferred execution provider first followed by the less preferred ones.
* If none are called Ort will use its internal CPU execution provider.
*/
* To use additional providers, you must build ORT with the extra providers enabled. Then call one of these
* functions to enable them in the session:
* OrtSessionOptionsAppendExecutionProvider_CPU
* OrtSessionOptionsAppendExecutionProvider_CUDA
* OrtSessionOptionsAppendExecutionProvider_<remaining providers...>
* The order in which they are called indicates the preference order as well. In other words, call this method
* on your most preferred execution provider first followed by the less preferred ones.
* If none are called Ort will use its internal CPU execution provider.
*/

OrtStatus*(ORT_API_CALL* SessionGetInputCount)(_In_ const OrtSession* sess, _Out_ size_t* out)NO_EXCEPTION;
OrtStatus*(ORT_API_CALL* SessionGetOutputCount)(_In_ const OrtSession* sess, _Out_ size_t* out)NO_EXCEPTION;
Expand Down Expand Up @@ -432,39 +433,39 @@ struct OrtApi {
OrtStatus*(ORT_API_CALL* GetTensorMutableData)(_Inout_ OrtValue* value, _Outptr_ void** out)NO_EXCEPTION;

/**
* \param value A tensor created from CreateTensor... function.
* \param s A string array. Each string in this array must be null terminated.
* \param s_len length of s
*/
* \param value A tensor created from OrtCreateTensor... function.
* \param s A string array. Each string in this array must be null terminated.
* \param s_len length of s
*/
OrtStatus*(ORT_API_CALL* FillStringTensor)(_Inout_ OrtValue* value, _In_ const char* const* s, size_t s_len)NO_EXCEPTION;

/**
* \param value A tensor created from CreateTensor... function.
* \param len total data length, not including the trailing '\0' chars.
*/
* \param value A tensor created from OrtCreateTensor... function.
* \param len total data length, not including the trailing '\0' chars.
*/
OrtStatus*(ORT_API_CALL* GetStringTensorDataLength)(_In_ const OrtValue* value, _Out_ size_t* len)NO_EXCEPTION;

/**
* \param s string contents. Each string is NOT null-terminated.
* \param value A tensor created from CreateTensor... function.
* \param s_len total data length, get it from GetStringTensorDataLength
*/
* \param s string contents. Each string is NOT null-terminated.
* \param value A tensor created from OrtCreateTensor... function.
* \param s_len total data length, get it from OrtGetStringTensorDataLength
*/
OrtStatus*(ORT_API_CALL* GetStringTensorContent)(_In_ const OrtValue* value, _Out_ void* s, size_t s_len,
_Out_ size_t* offsets, size_t offsets_len)NO_EXCEPTION;

/**
* Don't free the 'out' value
*/
* Don't free the 'out' value
*/
OrtStatus*(ORT_API_CALL* CastTypeInfoToTensorInfo)(_In_ const OrtTypeInfo*, _Out_ const OrtTensorTypeAndShapeInfo** out)NO_EXCEPTION;

/**
* Return OnnxType from OrtTypeInfo
*/
* Return OnnxType from OrtTypeInfo
*/
OrtStatus*(ORT_API_CALL* GetOnnxTypeFromTypeInfo)(_In_ const OrtTypeInfo*, _Out_ enum ONNXType* out)NO_EXCEPTION;

/**
* The 'out' value should be released by calling OrtReleaseTensorTypeAndShapeInfo
*/
* The 'out' value should be released by calling OrtReleaseTensorTypeAndShapeInfo
*/
OrtStatus*(ORT_API_CALL* CreateTensorTypeAndShapeInfo)(_Outptr_ OrtTensorTypeAndShapeInfo** out)NO_EXCEPTION;

OrtStatus*(ORT_API_CALL* SetTensorElementType)(_Inout_ OrtTensorTypeAndShapeInfo*, enum ONNXTensorElementDataType type)NO_EXCEPTION;
Expand Down Expand Up @@ -592,36 +593,36 @@ struct OrtApi {
_Outptr_ OrtValue** out)NO_EXCEPTION;

/**
* Construct OrtValue that contains a value of non-standard type created for
* experiments or while awaiting standardization. OrtValue in this case would contain
* an internal representation of the Opaque type. Opaque types are distinguished between
* each other by two strings 1) domain and 2) type name. The combination of the two
* must be unique, so the type representation is properly identified internally. The combination
* must be properly registered from within ORT at both compile/run time or by another API.
*
* To construct the OrtValue pass domain and type names, also a pointer to a data container
* the type of which must be known to both ORT and the client program. That data container may or may
* not match the internal representation of the Opaque type. The sizeof(data_container) is passed for
* verification purposes.
*
* \domain_name - domain name for the Opaque type, null terminated.
* \type_name - type name for the Opaque type, null terminated.
* \data_container - data to populate OrtValue
* \data_container_size - sizeof() of the data container. Must match the sizeof() of the expected
* data_container size internally.
*/
* Construct OrtValue that contains a value of non-standard type created for
* experiments or while awaiting standardization. OrtValue in this case would contain
* an internal representation of the Opaque type. Opaque types are distinguished between
* each other by two strings 1) domain and 2) type name. The combination of the two
* must be unique, so the type representation is properly identified internally. The combination
* must be properly registered from within ORT at both compile/run time or by another API.
*
* To construct the OrtValue pass domain and type names, also a pointer to a data container
* the type of which must be known to both ORT and the client program. That data container may or may
* not match the internal representation of the Opaque type. The sizeof(data_container) is passed for
* verification purposes.
*
* \domain_name - domain name for the Opaque type, null terminated.
* \type_name - type name for the Opaque type, null terminated.
* \data_container - data to populate OrtValue
* \data_container_size - sizeof() of the data container. Must match the sizeof() of the expected
* data_container size internally.
*/
OrtStatus*(ORT_API_CALL* CreateOpaqueValue)(_In_ const char* domain_name, _In_ const char* type_name,
_In_ const void* data_container, size_t data_container_size, _Outptr_ OrtValue** out)NO_EXCEPTION;

/**
* Fetch data from an OrtValue that contains a value of non-standard type created for
* experiments or while awaiting standardization.
* \domain_name - domain name for the Opaque type, null terminated.
* \type_name - type name for the Opaque type, null terminated.
* \data_container - container that receives the data fetched from the OrtValue
* \data_container_size - sizeof() of the data container. Must match the sizeof() of the expected
* data_container size internally.
*/
* Fetch data from an OrtValue that contains a value of non-standard type created for
* experiments or while awaiting standardization.
* \domain_name - domain name for the Opaque type, null terminated.
* \type_name - type name for the Opaque type, null terminated.
* \data_container - container that receives the data fetched from the OrtValue
* \data_container_size - sizeof() of the data container. Must match the sizeof() of the expected
* data_container size internally.
*/

OrtStatus*(ORT_API_CALL* GetOpaqueValue)(_In_ const char* domain_name, _In_ const char* type_name,
_In_ const OrtValue* in, _Out_ void* data_container, size_t data_container_size)NO_EXCEPTION;
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/framework/allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ ORT_API_STATUS_IMPL(OrtApis::MemoryInfoGetMemType, _In_ const OrtMemoryInfo* ptr
}

ORT_API_STATUS_IMPL(OrtApis::MemoryInfoGetType, _In_ const OrtMemoryInfo* ptr, _Out_ OrtAllocatorType* out) {
*out = ptr->type;
*out = ptr->alloc_type;
return nullptr;
}

Expand Down
19 changes: 18 additions & 1 deletion onnxruntime/core/framework/execution_providers.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,25 @@ class ExecutionProviders {

// maps for fast lookup of an index into exec_providers_
std::unordered_map<std::string, size_t> provider_idx_map_;

// currently the allocator type is an implementation detail and we don't make any behavioral choices based on it,
// so exclude it from the key comparison for allocator_idx_map_.
// we also don't expect to have two allocators with the same name, one using an arena and one not.
struct OrtMemoryInfoLessThanIgnoreAllocType {
bool operator()(const OrtMemoryInfo& lhs, const OrtMemoryInfo& rhs) const {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Any reason to not use OrtDevice for comparison?
  2. Also, it looks like storing id in OrtMemoryInfo seems redundant given that we already have OrtDevice that encapsulates the device id.

Copy link
Contributor Author

@skottmckay skottmckay Jan 12, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Do you mean OrtMemoryInfo? I used a custom comparator here as it may not always be the case that you want to ignore the allocator type.
  2. Not sure it's redundant. Couldn't you have multiple different allocators using the same device, so the 'id' in OrtMemoryInfo isn't necessarily 1:1 with the device id?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. No, I mean OrtMemoryInfo contains a member called 'device'. Any reason not to consider that for the comparison?
  2. The 'id' in OrtMemoryInfo was meant to represent a device id and then someone introduced OrtDevice which also contains a member called DeviceId device_id. Hence the confusion.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed, the code was copied from OrtMemoryInfo::operator< and 'device' is ignored there. Leaving potential removal of OrtMemoryInfo::id as a separate PR as it may change public interfaces.

/*if (lhs.alloc_type != rhs.alloc_type)
return lhs.alloc_type < rhs.alloc_type;*/
if (lhs.mem_type != rhs.mem_type)
return lhs.mem_type < rhs.mem_type;
if (lhs.id != rhs.id)
return lhs.id < rhs.id;

return strcmp(lhs.name, rhs.name) < 0;
}
};

// using std::map as OrtMemoryInfo would need a custom hash function to be used with unordered_map,
// and as this isn't performance critical it's not worth the maintenance overhead of adding one.
std::map<OrtMemoryInfo, size_t> allocator_idx_map_;
std::map<OrtMemoryInfo, size_t, OrtMemoryInfoLessThanIgnoreAllocType> allocator_idx_map_;
};
} // namespace onnxruntime
Loading