-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Ignore allocator type in ExecutionProviders allocator map. Make default initialization of OrtMemoryInfo more clearly invalid. #2768
Changes from all commits
2427b49
0aa0994
cac1d13
25c4475
16d1e9a
07aeef3
4309bfb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -84,12 +84,12 @@ typedef enum ONNXTensorElementDataType { | |
ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, // maps to c++ type std::string | ||
ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, | ||
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, | ||
ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, // maps to c type double | ||
ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, // maps to c type uint32_t | ||
ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, // maps to c type uint64_t | ||
ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64, // complex with float32 real and imaginary components | ||
ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128, // complex with float64 real and imaginary components | ||
ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 // Non-IEEE floating-point format based on IEEE754 single-precision | ||
ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, // maps to c type double | ||
ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, // maps to c type uint32_t | ||
ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, // maps to c type uint64_t | ||
ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64, // complex with float32 real and imaginary components | ||
ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128, // complex with float64 real and imaginary components | ||
ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 // Non-IEEE floating-point format based on IEEE754 single-precision | ||
} ONNXTensorElementDataType; | ||
|
||
// Synced with onnx TypeProto oneof | ||
|
@@ -193,6 +193,7 @@ struct OrtCustomOp; | |
typedef struct OrtCustomOp OrtCustomOp; | ||
|
||
typedef enum OrtAllocatorType { | ||
Invalid = -1, | ||
OrtDeviceAllocator = 0, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the only 'real' change here. The rest is auto formatting by clang including replacing tabs with spaces |
||
OrtArenaAllocator = 1 | ||
} OrtAllocatorType; | ||
|
@@ -234,14 +235,14 @@ struct OrtApi { | |
const char*(ORT_API_CALL* GetErrorMessage)(_In_ const OrtStatus* status)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; | ||
|
||
/** | ||
* \param out Should be freed by `OrtReleaseEnv` after use | ||
*/ | ||
* \param out Should be freed by `OrtReleaseEnv` after use | ||
*/ | ||
OrtStatus*(ORT_API_CALL* CreateEnv)(OrtLoggingLevel default_logging_level, _In_ const char* logid, _Outptr_ OrtEnv** out) | ||
NO_EXCEPTION ORT_ALL_ARGS_NONNULL; | ||
|
||
/** | ||
* \param out Should be freed by `OrtReleaseEnv` after use | ||
*/ | ||
* \param out Should be freed by `OrtReleaseEnv` after use | ||
*/ | ||
OrtStatus*(ORT_API_CALL* CreateEnvWithCustomLogger)(OrtLoggingFunction logging_function, | ||
_In_opt_ void* logger_param, OrtLoggingLevel default_warning_level, | ||
_In_ const char* logid, | ||
|
@@ -268,8 +269,8 @@ struct OrtApi { | |
_In_ const char* const* output_names, size_t output_names_len, _Outptr_ OrtValue** output)NO_EXCEPTION; | ||
|
||
/** | ||
* \return A pointer of the newly created object. The pointer should be freed by OrtReleaseSessionOptions after use | ||
*/ | ||
* \return A pointer of the newly created object. The pointer should be freed by OrtReleaseSessionOptions after use | ||
*/ | ||
OrtStatus*(ORT_API_CALL* CreateSessionOptions)(_Outptr_ OrtSessionOptions** options)NO_EXCEPTION; | ||
|
||
// Set filepath to save optimized model after graph level transformations. | ||
|
@@ -325,36 +326,36 @@ struct OrtApi { | |
OrtStatus*(ORT_API_CALL* CreateCustomOpDomain)(_In_ const char* domain, _Outptr_ OrtCustomOpDomain** out)NO_EXCEPTION; | ||
|
||
/* | ||
* Add custom ops to the OrtCustomOpDomain | ||
* Note: The OrtCustomOp* pointer must remain valid until the OrtCustomOpDomain using it is released | ||
*/ | ||
* Add custom ops to the OrtCustomOpDomain | ||
* Note: The OrtCustomOp* pointer must remain valid until the OrtCustomOpDomain using it is released | ||
*/ | ||
OrtStatus*(ORT_API_CALL* CustomOpDomain_Add)(_Inout_ OrtCustomOpDomain* custom_op_domain, _In_ OrtCustomOp* op)NO_EXCEPTION; | ||
|
||
/* | ||
* Add a custom op domain to the OrtSessionOptions | ||
* Note: The OrtCustomOpDomain* must not be deleted until the sessions using it are released | ||
*/ | ||
* Add a custom op domain to the OrtSessionOptions | ||
* Note: The OrtCustomOpDomain* must not be deleted until the sessions using it are released | ||
*/ | ||
OrtStatus*(ORT_API_CALL* AddCustomOpDomain)(_Inout_ OrtSessionOptions* options, _In_ OrtCustomOpDomain* custom_op_domain)NO_EXCEPTION; | ||
|
||
/* | ||
* Loads a DLL named 'library_path' and looks for this entry point: | ||
* OrtStatus* RegisterCustomOps(OrtSessionOptions * options, const OrtApiBase* api); | ||
* It then passes in the provided session options to this function along with the api base. | ||
* The handle to the loaded library is returned in library_handle. It can be freed by the caller after all sessions using the passed in | ||
* session options are destroyed, or if an error occurs and it is non null. | ||
* Loads a DLL named 'library_path' and looks for this entry point: | ||
* OrtStatus* RegisterCustomOps(OrtSessionOptions * options, const OrtApiBase* api); | ||
* It then passes in the provided session options to this function along with the api base. | ||
* The handle to the loaded library is returned in library_handle. It can be freed by the caller after all sessions using the passed in | ||
* session options are destroyed, or if an error occurs and it is non null. | ||
*/ | ||
OrtStatus*(ORT_API_CALL* RegisterCustomOpsLibrary)(_Inout_ OrtSessionOptions* options, _In_ const char* library_path, void** library_handle)NO_EXCEPTION; | ||
|
||
/** | ||
* To use additional providers, you must build ORT with the extra providers enabled. Then call one of these | ||
* functions to enable them in the session: | ||
* OrtSessionOptionsAppendExecutionProvider_CPU | ||
* OrtSessionOptionsAppendExecutionProvider_CUDA | ||
* OrtSessionOptionsAppendExecutionProvider_<remaining providers...> | ||
* The order they care called indicates the preference order as well. In other words call this method | ||
* on your most preferred execution provider first followed by the less preferred ones. | ||
* If none are called Ort will use its internal CPU execution provider. | ||
*/ | ||
* To use additional providers, you must build ORT with the extra providers enabled. Then call one of these | ||
* functions to enable them in the session: | ||
* OrtSessionOptionsAppendExecutionProvider_CPU | ||
* OrtSessionOptionsAppendExecutionProvider_CUDA | ||
* OrtSessionOptionsAppendExecutionProvider_<remaining providers...> | ||
* The order they care called indicates the preference order as well. In other words call this method | ||
* on your most preferred execution provider first followed by the less preferred ones. | ||
* If none are called Ort will use its internal CPU execution provider. | ||
*/ | ||
|
||
OrtStatus*(ORT_API_CALL* SessionGetInputCount)(_In_ const OrtSession* sess, _Out_ size_t* out)NO_EXCEPTION; | ||
OrtStatus*(ORT_API_CALL* SessionGetOutputCount)(_In_ const OrtSession* sess, _Out_ size_t* out)NO_EXCEPTION; | ||
|
@@ -432,39 +433,39 @@ struct OrtApi { | |
OrtStatus*(ORT_API_CALL* GetTensorMutableData)(_Inout_ OrtValue* value, _Outptr_ void** out)NO_EXCEPTION; | ||
|
||
/** | ||
* \param value A tensor created from CreateTensor... function. | ||
* \param s each A string array. Each string in this array must be null terminated. | ||
* \param s_len length of s | ||
*/ | ||
* \param value A tensor created from OrtCreateTensor... function. | ||
* \param s each A string array. Each string in this array must be null terminated. | ||
* \param s_len length of s | ||
*/ | ||
OrtStatus*(ORT_API_CALL* FillStringTensor)(_Inout_ OrtValue* value, _In_ const char* const* s, size_t s_len)NO_EXCEPTION; | ||
|
||
/** | ||
* \param value A tensor created from CreateTensor... function. | ||
* \param len total data length, not including the trailing '\0' chars. | ||
*/ | ||
* \param value A tensor created from OrtCreateTensor... function. | ||
* \param len total data length, not including the trailing '\0' chars. | ||
*/ | ||
OrtStatus*(ORT_API_CALL* GetStringTensorDataLength)(_In_ const OrtValue* value, _Out_ size_t* len)NO_EXCEPTION; | ||
|
||
/** | ||
* \param s string contents. Each string is NOT null-terminated. | ||
* \param value A tensor created from CreateTensor... function. | ||
* \param s_len total data length, get it from GetStringTensorDataLength | ||
*/ | ||
* \param s string contents. Each string is NOT null-terminated. | ||
* \param value A tensor created from OrtCreateTensor... function. | ||
* \param s_len total data length, get it from OrtGetStringTensorDataLength | ||
*/ | ||
OrtStatus*(ORT_API_CALL* GetStringTensorContent)(_In_ const OrtValue* value, _Out_ void* s, size_t s_len, | ||
_Out_ size_t* offsets, size_t offsets_len)NO_EXCEPTION; | ||
|
||
/** | ||
* Don't free the 'out' value | ||
*/ | ||
* Don't free the 'out' value | ||
*/ | ||
OrtStatus*(ORT_API_CALL* CastTypeInfoToTensorInfo)(_In_ const OrtTypeInfo*, _Out_ const OrtTensorTypeAndShapeInfo** out)NO_EXCEPTION; | ||
|
||
/** | ||
* Return OnnxType from OrtTypeInfo | ||
*/ | ||
* Return OnnxType from OrtTypeInfo | ||
*/ | ||
OrtStatus*(ORT_API_CALL* GetOnnxTypeFromTypeInfo)(_In_ const OrtTypeInfo*, _Out_ enum ONNXType* out)NO_EXCEPTION; | ||
|
||
/** | ||
* The 'out' value should be released by calling OrtReleaseTensorTypeAndShapeInfo | ||
*/ | ||
* The 'out' value should be released by calling OrtReleaseTensorTypeAndShapeInfo | ||
*/ | ||
OrtStatus*(ORT_API_CALL* CreateTensorTypeAndShapeInfo)(_Outptr_ OrtTensorTypeAndShapeInfo** out)NO_EXCEPTION; | ||
|
||
OrtStatus*(ORT_API_CALL* SetTensorElementType)(_Inout_ OrtTensorTypeAndShapeInfo*, enum ONNXTensorElementDataType type)NO_EXCEPTION; | ||
|
@@ -592,36 +593,36 @@ struct OrtApi { | |
_Outptr_ OrtValue** out)NO_EXCEPTION; | ||
|
||
/** | ||
* Construct OrtValue that contains a value of non-standard type created for | ||
* experiments or while awaiting standardization. OrtValue in this case would contain | ||
* an internal representation of the Opaque type. Opaque types are distinguished between | ||
* each other by two strings 1) domain and 2) type name. The combination of the two | ||
* must be unique, so the type representation is properly identified internally. The combination | ||
* must be properly registered from within ORT at both compile/run time or by another API. | ||
* | ||
* To construct the OrtValue pass domain and type names, also a pointer to a data container | ||
* the type of which must be know to both ORT and the client program. That data container may or may | ||
* not match the internal representation of the Opaque type. The sizeof(data_container) is passed for | ||
* verification purposes. | ||
* | ||
* \domain_name - domain name for the Opaque type, null terminated. | ||
* \type_name - type name for the Opaque type, null terminated. | ||
* \data_contianer - data to populate OrtValue | ||
* \data_container_size - sizeof() of the data container. Must match the sizeof() of the expected | ||
* data_container size internally. | ||
*/ | ||
* Construct OrtValue that contains a value of non-standard type created for | ||
* experiments or while awaiting standardization. OrtValue in this case would contain | ||
* an internal representation of the Opaque type. Opaque types are distinguished between | ||
* each other by two strings 1) domain and 2) type name. The combination of the two | ||
* must be unique, so the type representation is properly identified internally. The combination | ||
* must be properly registered from within ORT at both compile/run time or by another API. | ||
* | ||
* To construct the OrtValue pass domain and type names, also a pointer to a data container | ||
* the type of which must be know to both ORT and the client program. That data container may or may | ||
* not match the internal representation of the Opaque type. The sizeof(data_container) is passed for | ||
* verification purposes. | ||
* | ||
* \domain_name - domain name for the Opaque type, null terminated. | ||
* \type_name - type name for the Opaque type, null terminated. | ||
* \data_contianer - data to populate OrtValue | ||
* \data_container_size - sizeof() of the data container. Must match the sizeof() of the expected | ||
* data_container size internally. | ||
*/ | ||
OrtStatus*(ORT_API_CALL* CreateOpaqueValue)(_In_ const char* domain_name, _In_ const char* type_name, | ||
_In_ const void* data_container, size_t data_container_size, _Outptr_ OrtValue** out)NO_EXCEPTION; | ||
|
||
/** | ||
* Fetch data from an OrtValue that contains a value of non-standard type created for | ||
* experiments or while awaiting standardization. | ||
* \domain_name - domain name for the Opaque type, null terminated. | ||
* \type_name - type name for the Opaque type, null terminated. | ||
* \data_contianer - data to populate OrtValue | ||
* \data_container_size - sizeof() of the data container. Must match the sizeof() of the expected | ||
* data_container size internally. | ||
*/ | ||
* Fetch data from an OrtValue that contains a value of non-standard type created for | ||
* experiments or while awaiting standardization. | ||
* \domain_name - domain name for the Opaque type, null terminated. | ||
* \type_name - type name for the Opaque type, null terminated. | ||
* \data_contianer - data to populate OrtValue | ||
* \data_container_size - sizeof() of the data container. Must match the sizeof() of the expected | ||
* data_container size internally. | ||
*/ | ||
|
||
OrtStatus*(ORT_API_CALL* GetOpaqueValue)(_In_ const char* domain_name, _In_ const char* type_name, | ||
_In_ const OrtValue* in, _Out_ void* data_container, size_t data_container_size)NO_EXCEPTION; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -107,8 +107,25 @@ class ExecutionProviders { | |
|
||
// maps for fast lookup of an index into exec_providers_ | ||
std::unordered_map<std::string, size_t> provider_idx_map_; | ||
|
||
// currently the allocator type is an implementation detail and we don't make any behavioral choices based on it, | ||
// so exclude it from the key comparison for allocator_idx_map_. | ||
// we also don't expect to have two allocators with the same name, one using an arena and one not. | ||
struct OrtMemoryInfoLessThanIgnoreAllocType { | ||
bool operator()(const OrtMemoryInfo& lhs, const OrtMemoryInfo& rhs) const { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As discussed, the code was copied from OrtMemoryInfo::operator< and 'device' is ignored there. Leaving potential removal of OrtMemoryInfo::id as a separate PR as it may change public interfaces. |
||
/*if (lhs.alloc_type != rhs.alloc_type) | ||
return lhs.alloc_type < rhs.alloc_type;*/ | ||
if (lhs.mem_type != rhs.mem_type) | ||
return lhs.mem_type < rhs.mem_type; | ||
if (lhs.id != rhs.id) | ||
return lhs.id < rhs.id; | ||
|
||
return strcmp(lhs.name, rhs.name) < 0; | ||
} | ||
}; | ||
|
||
// using std::map as OrtMemoryInfo would need a custom hash function to be used with unordered_map, | ||
// and as this isn't performance critical it's not worth the maintenance overhead of adding one. | ||
std::map<OrtMemoryInfo, size_t> allocator_idx_map_; | ||
std::map<OrtMemoryInfo, size_t, OrtMemoryInfoLessThanIgnoreAllocType> allocator_idx_map_; | ||
}; | ||
} // namespace onnxruntime |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Definition is implicitly inline, so 'inline' keyword here is unnecessary.