diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h index f6a31494768c8..1fdcdfc02e313 100644 --- a/include/onnxruntime/core/framework/allocator.h +++ b/include/onnxruntime/core/framework/allocator.h @@ -85,26 +85,27 @@ struct OrtMemoryInfo { // use string for name, so we could have customized allocator in execution provider. const char* name; - int id; - OrtMemType mem_type; - OrtAllocatorType type; + int id = -1; + OrtMemType mem_type = OrtMemTypeDefault; + OrtAllocatorType alloc_type = Invalid; OrtDevice device; - constexpr OrtMemoryInfo(const char* name_, OrtAllocatorType type_, OrtDevice device_ = OrtDevice(), int id_ = 0, OrtMemType mem_type_ = OrtMemTypeDefault) + constexpr OrtMemoryInfo(const char* name_, OrtAllocatorType type_, OrtDevice device_ = OrtDevice(), int id_ = 0, + OrtMemType mem_type_ = OrtMemTypeDefault) #if (defined(__GNUC__) || defined(__clang__)) __attribute__((nonnull)) #endif : name(name_), id(id_), mem_type(mem_type_), - type(type_), + alloc_type(type_), device(device_) { } // To make OrtMemoryInfo become a valid key in std map - inline bool operator<(const OrtMemoryInfo& other) const { - if (type != other.type) - return type < other.type; + bool operator<(const OrtMemoryInfo& other) const { + if (alloc_type != other.alloc_type) + return alloc_type < other.alloc_type; if (mem_type != other.mem_type) return mem_type < other.mem_type; if (id != other.id) @@ -113,20 +114,22 @@ struct OrtMemoryInfo { return strcmp(name, other.name) < 0; } - inline std::string ToString() const { + std::string ToString() const { std::ostringstream ostr; ostr << "OrtMemoryInfo: [" << " name:" << name << " id:" << id << " mem_type:" << mem_type - << " type:" << type + << " alloc_type:" << alloc_type << "]"; return ostr.str(); } }; inline bool operator==(const OrtMemoryInfo& left, const OrtMemoryInfo& other) { - return left.mem_type == other.mem_type && left.type == other.type && left.id == other.id && + 
return left.mem_type == other.mem_type && + left.alloc_type == other.alloc_type && + left.id == other.id && strcmp(left.name, other.name) == 0; } @@ -213,9 +216,11 @@ class IAllocator { if (!std::is_void::value) { // sizeof(void) isn't valid, but the compiler isn't smart enough to ignore that this line isn't // reachable if T is void. use std::conditional to 'use' void* in the sizeof call - if (!CalcMemSizeForArray(count_or_bytes, sizeof(typename std::conditional::value, void*, T>::type), + if (!CalcMemSizeForArray(count_or_bytes, + sizeof(typename std::conditional::value, void*, T>::type), &alloc_size)) return nullptr; } + return IAllocatorUniquePtr{ static_cast(allocator->Alloc(alloc_size)), // allocate [=](T* ptr) { allocator->Free(ptr); }}; // capture IAllocator so it's always valid, and use as deleter @@ -224,22 +229,26 @@ class IAllocator { template bool IAllocator::CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t* out) noexcept { - static constexpr size_t max_allowed = (static_cast(1) << (static_cast(std::numeric_limits::digits >> 1))) - alignment; + static constexpr size_t max_allowed = (size_t(1) << (size_t(std::numeric_limits::digits >> 1))) - alignment; static constexpr size_t max_size = std::numeric_limits::max() - alignment; static constexpr size_t alignment_mask = alignment - 1; + //Indeed, we only need to check if max_size / nmemb < size //max_allowed is for avoiding unnecessary DIV. 
if (nmemb >= max_allowed && max_size / nmemb < size) { return false; } + if (size >= max_allowed && nmemb > 0 && max_size / nmemb < size) { return false; } + if (alignment == 0) *out = size * nmemb; else *out = (size * nmemb + alignment_mask) & ~static_cast(alignment_mask); + return true; } @@ -297,11 +306,10 @@ class MiMallocAllocator : public IDeviceAllocator { #endif - #ifdef USE_MIMALLOC - using TAllocator = MiMallocAllocator; +using TAllocator = MiMallocAllocator; #else - using TAllocator = CPUAllocator; +using TAllocator = CPUAllocator; #endif using AllocatorPtr = std::shared_ptr; diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index ed07c3bbb2d87..247d204f9b0de 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -84,12 +84,12 @@ typedef enum ONNXTensorElementDataType { ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, // maps to c++ type std::string ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, - ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, // maps to c type double - ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, // maps to c type uint32_t - ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, // maps to c type uint64_t - ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64, // complex with float32 real and imaginary components - ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128, // complex with float64 real and imaginary components - ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 // Non-IEEE floating-point format based on IEEE754 single-precision + ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, // maps to c type double + ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, // maps to c type uint32_t + ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, // maps to c type uint64_t + ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64, // complex with float32 real and imaginary components + ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128, // complex with float64 real and imaginary components + 
ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 // Non-IEEE floating-point format based on IEEE754 single-precision } ONNXTensorElementDataType; // Synced with onnx TypeProto oneof @@ -193,6 +193,7 @@ struct OrtCustomOp; typedef struct OrtCustomOp OrtCustomOp; typedef enum OrtAllocatorType { + Invalid = -1, OrtDeviceAllocator = 0, OrtArenaAllocator = 1 } OrtAllocatorType; @@ -234,14 +235,14 @@ struct OrtApi { const char*(ORT_API_CALL* GetErrorMessage)(_In_ const OrtStatus* status)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; /** - * \param out Should be freed by `OrtReleaseEnv` after use - */ + * \param out Should be freed by `OrtReleaseEnv` after use + */ OrtStatus*(ORT_API_CALL* CreateEnv)(OrtLoggingLevel default_logging_level, _In_ const char* logid, _Outptr_ OrtEnv** out) NO_EXCEPTION ORT_ALL_ARGS_NONNULL; /** - * \param out Should be freed by `OrtReleaseEnv` after use - */ + * \param out Should be freed by `OrtReleaseEnv` after use + */ OrtStatus*(ORT_API_CALL* CreateEnvWithCustomLogger)(OrtLoggingFunction logging_function, _In_opt_ void* logger_param, OrtLoggingLevel default_warning_level, _In_ const char* logid, @@ -268,8 +269,8 @@ struct OrtApi { _In_ const char* const* output_names, size_t output_names_len, _Outptr_ OrtValue** output)NO_EXCEPTION; /** - * \return A pointer of the newly created object. The pointer should be freed by OrtReleaseSessionOptions after use - */ + * \return A pointer of the newly created object. The pointer should be freed by OrtReleaseSessionOptions after use + */ OrtStatus*(ORT_API_CALL* CreateSessionOptions)(_Outptr_ OrtSessionOptions** options)NO_EXCEPTION; // Set filepath to save optimized model after graph level transformations. 
@@ -325,36 +326,36 @@ struct OrtApi { OrtStatus*(ORT_API_CALL* CreateCustomOpDomain)(_In_ const char* domain, _Outptr_ OrtCustomOpDomain** out)NO_EXCEPTION; /* - * Add custom ops to the OrtCustomOpDomain - * Note: The OrtCustomOp* pointer must remain valid until the OrtCustomOpDomain using it is released - */ + * Add custom ops to the OrtCustomOpDomain + * Note: The OrtCustomOp* pointer must remain valid until the OrtCustomOpDomain using it is released + */ OrtStatus*(ORT_API_CALL* CustomOpDomain_Add)(_Inout_ OrtCustomOpDomain* custom_op_domain, _In_ OrtCustomOp* op)NO_EXCEPTION; /* - * Add a custom op domain to the OrtSessionOptions - * Note: The OrtCustomOpDomain* must not be deleted until the sessions using it are released - */ + * Add a custom op domain to the OrtSessionOptions + * Note: The OrtCustomOpDomain* must not be deleted until the sessions using it are released + */ OrtStatus*(ORT_API_CALL* AddCustomOpDomain)(_Inout_ OrtSessionOptions* options, _In_ OrtCustomOpDomain* custom_op_domain)NO_EXCEPTION; /* - * Loads a DLL named 'library_path' and looks for this entry point: - * OrtStatus* RegisterCustomOps(OrtSessionOptions * options, const OrtApiBase* api); - * It then passes in the provided session options to this function along with the api base. - * The handle to the loaded library is returned in library_handle. It can be freed by the caller after all sessions using the passed in - * session options are destroyed, or if an error occurs and it is non null. + * Loads a DLL named 'library_path' and looks for this entry point: + * OrtStatus* RegisterCustomOps(OrtSessionOptions * options, const OrtApiBase* api); + * It then passes in the provided session options to this function along with the api base. + * The handle to the loaded library is returned in library_handle. It can be freed by the caller after all sessions using the passed in + * session options are destroyed, or if an error occurs and it is non null. 
*/ OrtStatus*(ORT_API_CALL* RegisterCustomOpsLibrary)(_Inout_ OrtSessionOptions* options, _In_ const char* library_path, void** library_handle)NO_EXCEPTION; /** - * To use additional providers, you must build ORT with the extra providers enabled. Then call one of these - * functions to enable them in the session: - * OrtSessionOptionsAppendExecutionProvider_CPU - * OrtSessionOptionsAppendExecutionProvider_CUDA - * OrtSessionOptionsAppendExecutionProvider_ - * The order they care called indicates the preference order as well. In other words call this method - * on your most preferred execution provider first followed by the less preferred ones. - * If none are called Ort will use its internal CPU execution provider. - */ + * To use additional providers, you must build ORT with the extra providers enabled. Then call one of these + * functions to enable them in the session: + * OrtSessionOptionsAppendExecutionProvider_CPU + * OrtSessionOptionsAppendExecutionProvider_CUDA + * OrtSessionOptionsAppendExecutionProvider_ + * The order they are called indicates the preference order as well. In other words call this method + * on your most preferred execution provider first followed by the less preferred ones. + * If none are called Ort will use its internal CPU execution provider. + */ OrtStatus*(ORT_API_CALL* SessionGetInputCount)(_In_ const OrtSession* sess, _Out_ size_t* out)NO_EXCEPTION; OrtStatus*(ORT_API_CALL* SessionGetOutputCount)(_In_ const OrtSession* sess, _Out_ size_t* out)NO_EXCEPTION; @@ -432,39 +433,39 @@ struct OrtApi { OrtStatus*(ORT_API_CALL* GetTensorMutableData)(_Inout_ OrtValue* value, _Outptr_ void** out)NO_EXCEPTION; /** - * \param value A tensor created from CreateTensor... function. - * \param s each A string array. Each string in this array must be null terminated. - * \param s_len length of s - */ + * \param value A tensor created from OrtCreateTensor... function. + * \param s each A string array. 
Each string in this array must be null terminated. + * \param s_len length of s + */ OrtStatus*(ORT_API_CALL* FillStringTensor)(_Inout_ OrtValue* value, _In_ const char* const* s, size_t s_len)NO_EXCEPTION; /** - * \param value A tensor created from CreateTensor... function. - * \param len total data length, not including the trailing '\0' chars. - */ + * \param value A tensor created from OrtCreateTensor... function. + * \param len total data length, not including the trailing '\0' chars. + */ OrtStatus*(ORT_API_CALL* GetStringTensorDataLength)(_In_ const OrtValue* value, _Out_ size_t* len)NO_EXCEPTION; /** - * \param s string contents. Each string is NOT null-terminated. - * \param value A tensor created from CreateTensor... function. - * \param s_len total data length, get it from GetStringTensorDataLength - */ + * \param s string contents. Each string is NOT null-terminated. + * \param value A tensor created from OrtCreateTensor... function. + * \param s_len total data length, get it from OrtGetStringTensorDataLength + */ OrtStatus*(ORT_API_CALL* GetStringTensorContent)(_In_ const OrtValue* value, _Out_ void* s, size_t s_len, _Out_ size_t* offsets, size_t offsets_len)NO_EXCEPTION; /** - * Don't free the 'out' value - */ + * Don't free the 'out' value + */ OrtStatus*(ORT_API_CALL* CastTypeInfoToTensorInfo)(_In_ const OrtTypeInfo*, _Out_ const OrtTensorTypeAndShapeInfo** out)NO_EXCEPTION; /** - * Return OnnxType from OrtTypeInfo - */ + * Return OnnxType from OrtTypeInfo + */ OrtStatus*(ORT_API_CALL* GetOnnxTypeFromTypeInfo)(_In_ const OrtTypeInfo*, _Out_ enum ONNXType* out)NO_EXCEPTION; /** - * The 'out' value should be released by calling OrtReleaseTensorTypeAndShapeInfo - */ + * The 'out' value should be released by calling OrtReleaseTensorTypeAndShapeInfo + */ OrtStatus*(ORT_API_CALL* CreateTensorTypeAndShapeInfo)(_Outptr_ OrtTensorTypeAndShapeInfo** out)NO_EXCEPTION; OrtStatus*(ORT_API_CALL* SetTensorElementType)(_Inout_ OrtTensorTypeAndShapeInfo*, enum 
ONNXTensorElementDataType type)NO_EXCEPTION; @@ -592,36 +593,36 @@ struct OrtApi { _Outptr_ OrtValue** out)NO_EXCEPTION; /** - * Construct OrtValue that contains a value of non-standard type created for - * experiments or while awaiting standardization. OrtValue in this case would contain - * an internal representation of the Opaque type. Opaque types are distinguished between - * each other by two strings 1) domain and 2) type name. The combination of the two - * must be unique, so the type representation is properly identified internally. The combination - * must be properly registered from within ORT at both compile/run time or by another API. - * - * To construct the OrtValue pass domain and type names, also a pointer to a data container - * the type of which must be know to both ORT and the client program. That data container may or may - * not match the internal representation of the Opaque type. The sizeof(data_container) is passed for - * verification purposes. - * - * \domain_name - domain name for the Opaque type, null terminated. - * \type_name - type name for the Opaque type, null terminated. - * \data_contianer - data to populate OrtValue - * \data_container_size - sizeof() of the data container. Must match the sizeof() of the expected - * data_container size internally. - */ + * Construct OrtValue that contains a value of non-standard type created for + * experiments or while awaiting standardization. OrtValue in this case would contain + * an internal representation of the Opaque type. Opaque types are distinguished between + * each other by two strings 1) domain and 2) type name. The combination of the two + * must be unique, so the type representation is properly identified internally. The combination + * must be properly registered from within ORT at both compile/run time or by another API. 
+ * + * To construct the OrtValue pass domain and type names, also a pointer to a data container + * the type of which must be known to both ORT and the client program. That data container may or may + * not match the internal representation of the Opaque type. The sizeof(data_container) is passed for + * verification purposes. + * + * \domain_name - domain name for the Opaque type, null terminated. + * \type_name - type name for the Opaque type, null terminated. + * \data_container - data to populate OrtValue + * \data_container_size - sizeof() of the data container. Must match the sizeof() of the expected + * data_container size internally. + */ OrtStatus*(ORT_API_CALL* CreateOpaqueValue)(_In_ const char* domain_name, _In_ const char* type_name, _In_ const void* data_container, size_t data_container_size, _Outptr_ OrtValue** out)NO_EXCEPTION; /** - * Fetch data from an OrtValue that contains a value of non-standard type created for - * experiments or while awaiting standardization. - * \domain_name - domain name for the Opaque type, null terminated. - * \type_name - type name for the Opaque type, null terminated. - * \data_contianer - data to populate OrtValue - * \data_container_size - sizeof() of the data container. Must match the sizeof() of the expected - * data_container size internally. - */ + * Fetch data from an OrtValue that contains a value of non-standard type created for + * experiments or while awaiting standardization. + * \domain_name - domain name for the Opaque type, null terminated. + * \type_name - type name for the Opaque type, null terminated. + * \data_container - data to populate OrtValue + * \data_container_size - sizeof() of the data container. Must match the sizeof() of the expected + * data_container size internally. 
+ */ OrtStatus*(ORT_API_CALL* GetOpaqueValue)(_In_ const char* domain_name, _In_ const char* type_name, _In_ const OrtValue* in, _Out_ void* data_container, size_t data_container_size)NO_EXCEPTION; diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index f58020df3f426..d3a567c4a207d 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -71,7 +71,7 @@ ORT_API_STATUS_IMPL(OrtApis::MemoryInfoGetMemType, _In_ const OrtMemoryInfo* ptr } ORT_API_STATUS_IMPL(OrtApis::MemoryInfoGetType, _In_ const OrtMemoryInfo* ptr, _Out_ OrtAllocatorType* out) { - *out = ptr->type; + *out = ptr->alloc_type; return nullptr; } diff --git a/onnxruntime/core/framework/execution_providers.h b/onnxruntime/core/framework/execution_providers.h index ec7e2fedc9e2a..b17c3a61c046d 100644 --- a/onnxruntime/core/framework/execution_providers.h +++ b/onnxruntime/core/framework/execution_providers.h @@ -107,8 +107,25 @@ class ExecutionProviders { // maps for fast lookup of an index into exec_providers_ std::unordered_map provider_idx_map_; + + // currently the allocator type is an implementation detail and we don't make any behavioral choices based on it, + // so exclude it from the key comparison for allocator_idx_map_. + // we also don't expect to have two allocators with the same name, one using an arena and one not. + struct OrtMemoryInfoLessThanIgnoreAllocType { + bool operator()(const OrtMemoryInfo& lhs, const OrtMemoryInfo& rhs) const { + /*if (lhs.alloc_type != rhs.alloc_type) + return lhs.alloc_type < rhs.alloc_type;*/ + if (lhs.mem_type != rhs.mem_type) + return lhs.mem_type < rhs.mem_type; + if (lhs.id != rhs.id) + return lhs.id < rhs.id; + + return strcmp(lhs.name, rhs.name) < 0; + } + }; + // using std::map as OrtMemoryInfo would need a custom hash function to be used with unordered_map, // and as this isn't performance critical it's not worth the maintenance overhead of adding one. 
- std::map allocator_idx_map_; + std::map allocator_idx_map_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/ort_value_name_idx_map.h b/onnxruntime/core/framework/ort_value_name_idx_map.h index dc906642d67e9..74e41c79a3f40 100644 --- a/onnxruntime/core/framework/ort_value_name_idx_map.h +++ b/onnxruntime/core/framework/ort_value_name_idx_map.h @@ -22,8 +22,7 @@ class OrtValueNameIdxMap { int Add(const std::string& name) { auto it = map_.find(name); if (it == map_.end()) { - int idx; - idx = ort_value_max_idx_++; + int idx = next_idx_++; map_.insert(it, {name, idx}); return idx; } @@ -43,7 +42,7 @@ class OrtValueNameIdxMap { } size_t Size() const { return map_.size(); }; - int MaxIdx() const { return ort_value_max_idx_; } + int MaxIdx() const { return next_idx_ - 1; } const_iterator begin() const noexcept { return map_.cbegin(); } const_iterator end() const noexcept { return map_.cend(); } @@ -51,8 +50,7 @@ class OrtValueNameIdxMap { private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OrtValueNameIdxMap); - int ort_value_max_idx_ = 0; + int next_idx_ = 0; std::unordered_map map_; }; -using OrtValueNameIdxMap = OrtValueNameIdxMap; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/sequential_execution_plan.h b/onnxruntime/core/framework/sequential_execution_plan.h index 3e3dca68020cb..924b62d395432 100644 --- a/onnxruntime/core/framework/sequential_execution_plan.h +++ b/onnxruntime/core/framework/sequential_execution_plan.h @@ -33,7 +33,7 @@ struct AllocPlanPerValue { bool create_fence_if_async{false}; public: - AllocPlanPerValue() : location(CPU, OrtArenaAllocator) {} + AllocPlanPerValue() : location(CPU, Invalid) {} }; // SequentialExecutionPlan: This is the data that is produced by a static diff --git a/onnxruntime/core/framework/session_state_initializer.cc b/onnxruntime/core/framework/session_state_initializer.cc index 3dd6ad72691fa..a80d6b6d47c85 100644 --- a/onnxruntime/core/framework/session_state_initializer.cc +++ 
b/onnxruntime/core/framework/session_state_initializer.cc @@ -150,7 +150,7 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st const Tensor& p_deserialize_tensor = tmp_ort_value.Get(); p_tensor = onnxruntime::make_unique(p_deserialize_tensor.DataType(), p_deserialize_tensor.Shape(), m.GetBuffer(), - m.GetAllocInfo()); + m.GetAllocInfo()); // TODO: does this function work for string tensor? Status copy_status = data_transfer_mgr.CopyTensor(p_deserialize_tensor, *p_tensor); if (d.f) d.f(d.param); @@ -177,7 +177,7 @@ common::Status SaveInitializedTensors(const Env& env, const std::basic_string 0, "OrtValue indexes should have been populated."); + ORT_ENFORCE(ort_value_name_idx_map.MaxIdx() > -1, "OrtValue indexes should have been populated."); //1. first plan the memory const onnxruntime::InitializedTensorSet& initialized_tensor_set = graph.GetAllInitializedTensors(); diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.h b/onnxruntime/core/providers/cpu/cpu_execution_provider.h index c7d36d1f9bfa8..79099e65ba844 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.h +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.h @@ -38,9 +38,7 @@ class CPUExecutionProvider : public IExecutionProvider { ORT_UNUSED_PARAMETER(info); //JEMalloc already has memory pool, so just use device allocator. 
- InsertAllocator( - std::shared_ptr( - onnxruntime::make_unique(device_info.factory(0)))); + InsertAllocator(device_info.factory(0)); #else //Disable Arena allocator for x86_32 build because it may run into infinite loop when integer overflow happens #if defined(__amd64__) || defined(_M_AMD64) @@ -50,10 +48,7 @@ class CPUExecutionProvider : public IExecutionProvider { #else ORT_UNUSED_PARAMETER(info); #endif - InsertAllocator( - std::shared_ptr( - onnxruntime::make_unique(device_info.factory(0)))); - + InsertAllocator(device_info.factory(0)); #endif } diff --git a/onnxruntime/test/framework/allocator_test.cc b/onnxruntime/test/framework/allocator_test.cc index 6fe5faf810ae9..9a5ded2146b7b 100644 --- a/onnxruntime/test/framework/allocator_test.cc +++ b/onnxruntime/test/framework/allocator_test.cc @@ -13,7 +13,13 @@ TEST(AllocatorTest, CPUAllocatorTest) { ASSERT_STREQ(cpu_arena->Info().name, CPU); EXPECT_EQ(cpu_arena->Info().id, 0); - EXPECT_EQ(cpu_arena->Info().type, OrtAllocatorType::OrtArenaAllocator); + + // arena is disabled for CPUExecutionProvider on x86 and JEMalloc +#if (defined(__amd64__) || defined(_M_AMD64)) && !defined(USE_JEMALLOC) + EXPECT_EQ(cpu_arena->Info().alloc_type, OrtAllocatorType::OrtArenaAllocator); +#else + EXPECT_EQ(cpu_arena->Info().alloc_type, OrtAllocatorType::OrtDeviceAllocator); +#endif size_t size = 1024; auto bytes = cpu_arena->Alloc(size); diff --git a/onnxruntime/test/framework/cuda/allocator_cuda_test.cc b/onnxruntime/test/framework/cuda/allocator_cuda_test.cc index 4bcca7ee7bd66..b0388b457a156 100644 --- a/onnxruntime/test/framework/cuda/allocator_cuda_test.cc +++ b/onnxruntime/test/framework/cuda/allocator_cuda_test.cc @@ -11,8 +11,10 @@ namespace onnxruntime { namespace test { TEST(AllocatorTest, CUDAAllocatorTest) { int cuda_device_id = 0; - DeviceAllocatorRegistrationInfo default_memory_info({OrtMemTypeDefault, - [](int id) { return onnxruntime::make_unique(id, CUDA); }, std::numeric_limits::max()}); + 
DeviceAllocatorRegistrationInfo default_memory_info( + {OrtMemTypeDefault, + [](int id) { return onnxruntime::make_unique(id, CUDA); }, + std::numeric_limits::max()}); auto cuda_arena = CreateAllocator(default_memory_info, cuda_device_id); @@ -21,21 +23,23 @@ TEST(AllocatorTest, CUDAAllocatorTest) { EXPECT_STREQ(cuda_arena->Info().name, CUDA); EXPECT_EQ(cuda_arena->Info().id, cuda_device_id); EXPECT_EQ(cuda_arena->Info().mem_type, OrtMemTypeDefault); - EXPECT_EQ(cuda_arena->Info().type, OrtArenaAllocator); + EXPECT_EQ(cuda_arena->Info().alloc_type, OrtArenaAllocator); //test cuda allocation auto cuda_addr = cuda_arena->Alloc(size); EXPECT_TRUE(cuda_addr); - DeviceAllocatorRegistrationInfo pinned_memory_info({OrtMemTypeCPUOutput, - [](int) { return onnxruntime::make_unique(0, CUDA_PINNED); }, std::numeric_limits::max()}); + DeviceAllocatorRegistrationInfo pinned_memory_info( + {OrtMemTypeCPUOutput, + [](int) { return onnxruntime::make_unique(0, CUDA_PINNED); }, + std::numeric_limits::max()}); auto pinned_allocator = CreateAllocator(pinned_memory_info); EXPECT_STREQ(pinned_allocator->Info().name, CUDA_PINNED); EXPECT_EQ(pinned_allocator->Info().id, 0); EXPECT_EQ(pinned_allocator->Info().mem_type, OrtMemTypeCPUOutput); - EXPECT_EQ(pinned_allocator->Info().type, OrtArenaAllocator); + EXPECT_EQ(pinned_allocator->Info().alloc_type, OrtArenaAllocator); //test pinned allocation auto pinned_addr = pinned_allocator->Alloc(size); @@ -45,7 +49,7 @@ TEST(AllocatorTest, CUDAAllocatorTest) { EXPECT_STREQ(cpu_arena->Info().name, CPU); EXPECT_EQ(cpu_arena->Info().id, 0); EXPECT_EQ(cpu_arena->Info().mem_type, OrtMemTypeDefault); - EXPECT_EQ(cpu_arena->Info().type, OrtArenaAllocator); + EXPECT_EQ(cpu_arena->Info().alloc_type, OrtArenaAllocator); auto cpu_addr_a = cpu_arena->Alloc(size); EXPECT_TRUE(cpu_addr_a); diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 411bba33d987a..2063852cdea84 100644 --- 
a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -86,8 +86,7 @@ class FuseExecutionProvider : public IExecutionProvider { DeviceAllocatorRegistrationInfo device_info({OrtMemTypeDefault, [](int) { return onnxruntime::make_unique(); }, std::numeric_limits::max()}); - InsertAllocator(std::shared_ptr( - onnxruntime::make_unique(device_info.factory(0)))); + InsertAllocator(device_info.factory(0)); } std::vector> @@ -776,8 +775,8 @@ static void TestBindHelper(const std::string& log_str, std::string s1; p_model->ToProto().SerializeToString(&s1); std::stringstream sstr(s1); - ASSERT_TRUE(session_object.Load(sstr).IsOK()); - ASSERT_TRUE(session_object.Initialize().IsOK()); + ASSERT_STATUS_OK(session_object.Load(sstr)); + ASSERT_STATUS_OK(session_object.Initialize()); RunOptions run_options; run_options.run_log_verbosity_level = so.session_log_verbosity_level; diff --git a/onnxruntime/test/framework/tensor_test.cc b/onnxruntime/test/framework/tensor_test.cc index d6b20a4b41e3a..6251b0259e719 100644 --- a/onnxruntime/test/framework/tensor_test.cc +++ b/onnxruntime/test/framework/tensor_test.cc @@ -136,7 +136,13 @@ TEST(TensorTest, EmptyTensorTest) { auto& location = t.Location(); ASSERT_STREQ(location.name, CPU); EXPECT_EQ(location.id, 0); - EXPECT_EQ(location.type, OrtAllocatorType::OrtArenaAllocator); + + // arena is disabled for CPUExecutionProvider on x86 and JEMalloc +#if (defined(__amd64__) || defined(_M_AMD64)) && !defined(USE_JEMALLOC) + EXPECT_EQ(location.alloc_type, OrtAllocatorType::OrtArenaAllocator); +#else + EXPECT_EQ(location.alloc_type, OrtAllocatorType::OrtDeviceAllocator); +#endif } TEST(TensorTest, StringTensorTest) {