This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #17 from dmlc/storage
Storage
- Loading branch information
Showing
16 changed files
with
730 additions
and
65 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -49,3 +49,7 @@ Debug | |
*.swp | ||
*.swo | ||
*.swn | ||
|
||
# Emacs | ||
.clang_complete | ||
.dir-locals.el |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,45 +1,70 @@ | ||
/*! | ||
* Copyright (c) 2015 by Contributors | ||
* Copyright (c) 2015 by Contributors | ||
* \file storage.h | ||
* \brief the memory allocator that manages the memory across multiple devices | ||
* \brief Storage manager across multiple devices. | ||
*/ | ||
#ifndef MXNET_STORAGE_H_ | ||
#define MXNET_STORAGE_H_ | ||
#include "./base.h" | ||
#include "./context.h" | ||
|
||
#include <memory> | ||
#include "base.h" | ||
#include "context.h" | ||
|
||
namespace mxnet { | ||
/*! \brief memory allocator of storage */ | ||
class StorageManager { | ||
|
||
/*! | ||
* \brief Storage manager across multiple devices. | ||
*/ | ||
class Storage { | ||
public: | ||
/*! | ||
* \brief storage handle the represents storage information | ||
* \brief Storage handle. | ||
*/ | ||
struct Handle { | ||
/*! \brief pointer to the data */ | ||
void *dptr; | ||
/*! \brief context information about device and deviceID */ | ||
Context ctx; | ||
/*! | ||
* \brief internal handle reserved for manager, | ||
* user should not change or use this | ||
* \brief Pointer to the data. | ||
*/ | ||
void *handle_; | ||
void* dptr; | ||
/*! | ||
* \brief Size of the storage. | ||
*/ | ||
size_t size; | ||
/*! | ||
* \brief Context information about device and ID. | ||
*/ | ||
Context ctx; | ||
}; | ||
/*! | ||
* \brief allocate a new contiguous memory for a given size | ||
* \param size the total size of memory in bytes | ||
* \param ctx context information about the device and deviceID | ||
* \return Handle struct | ||
* \brief Allocate a new contiguous memory for a given size. | ||
* \param size Total size of memory in bytes. | ||
* \param ctx Context information about the device and ID. | ||
* \return Handle struct. | ||
*/ | ||
virtual Handle Alloc(size_t size, Context ctx) = 0; | ||
Handle Alloc(size_t size, Context ctx); | ||
/*! | ||
* \brief free the space represened the handle | ||
* \param handle the handle to memory to be freed | ||
* \brief Free storage. | ||
* \param handle Handle struect. | ||
*/ | ||
virtual void Free(Handle handle) = 0; | ||
/*! \return storage manager singleton */ | ||
static StorageManager *Get(); | ||
}; // class StorageManager | ||
void Free(Handle handle); | ||
/*! | ||
* \brief Destructor. | ||
*/ | ||
~Storage(); | ||
/*! | ||
* \return Storage singleton. | ||
*/ | ||
static Storage* Get(); | ||
|
||
private: | ||
/*! | ||
* \brief Hidden constructors. | ||
*/ | ||
Storage(); | ||
struct Impl; | ||
std::unique_ptr<Impl> impl_; | ||
DISALLOW_COPY_AND_ASSIGN(Storage); | ||
}; // class Storage | ||
|
||
} // namespace mxnet | ||
|
||
#endif // MXNET_STORAGE_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,9 @@ | |
#include <thread> | ||
#include <cstdio> | ||
|
||
/*! | ||
* \brief Common components. | ||
*/ | ||
namespace common { | ||
|
||
/*! | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
/*! | ||
* Copyright (c) 2015 by Contributors | ||
* \file cuda_utils.h | ||
* \brief CUDA debugging utilities. | ||
*/ | ||
#ifndef MXNET_COMMON_CUDA_UTILS_H_ | ||
#define MXNET_COMMON_CUDA_UTILS_H_ | ||
|
||
#include <dmlc/logging.h> | ||
|
||
#if MXNET_USE_CUDA | ||
|
||
#include <cuda_runtime.h> | ||
#include <cublas_v2.h> | ||
#include <curand.h> | ||
|
||
namespace common { | ||
|
||
/*! | ||
* \brief CUDA utilities. | ||
*/ | ||
namespace cuda { | ||
|
||
/*! | ||
* \brief Get string representation of cuBLAS errors. | ||
* \param error The error. | ||
* \return String representation. | ||
*/ | ||
inline const char* CublasGetErrorString(cublasStatus_t error) { | ||
switch (error) { | ||
case CUBLAS_STATUS_SUCCESS: | ||
return "CUBLAS_STATUS_SUCCESS"; | ||
case CUBLAS_STATUS_NOT_INITIALIZED: | ||
return "CUBLAS_STATUS_NOT_INITIALIZED"; | ||
case CUBLAS_STATUS_ALLOC_FAILED: | ||
return "CUBLAS_STATUS_ALLOC_FAILED"; | ||
case CUBLAS_STATUS_INVALID_VALUE: | ||
return "CUBLAS_STATUS_INVALID_VALUE"; | ||
case CUBLAS_STATUS_ARCH_MISMATCH: | ||
return "CUBLAS_STATUS_ARCH_MISMATCH"; | ||
case CUBLAS_STATUS_MAPPING_ERROR: | ||
return "CUBLAS_STATUS_MAPPING_ERROR"; | ||
case CUBLAS_STATUS_EXECUTION_FAILED: | ||
return "CUBLAS_STATUS_EXECUTION_FAILED"; | ||
case CUBLAS_STATUS_INTERNAL_ERROR: | ||
return "CUBLAS_STATUS_INTERNAL_ERROR"; | ||
case CUBLAS_STATUS_NOT_SUPPORTED: | ||
return "CUBLAS_STATUS_NOT_SUPPORTED"; | ||
default: | ||
break; | ||
} | ||
return "Unknown cuBLAS status"; | ||
} | ||
|
||
/*! | ||
* \brief Get string representation of cuRAND errors. | ||
* \param status The status. | ||
* \return String representation. | ||
*/ | ||
inline const char* CurandGetErrorString(curandStatus_t status) { | ||
switch (status) { | ||
case CURAND_STATUS_SUCCESS: | ||
return "CURAND_STATUS_SUCCESS"; | ||
case CURAND_STATUS_VERSION_MISMATCH: | ||
return "CURAND_STATUS_VERSION_MISMATCH"; | ||
case CURAND_STATUS_NOT_INITIALIZED: | ||
return "CURAND_STATUS_NOT_INITIALIZED"; | ||
case CURAND_STATUS_ALLOCATION_FAILED: | ||
return "CURAND_STATUS_ALLOCATION_FAILED"; | ||
case CURAND_STATUS_TYPE_ERROR: | ||
return "CURAND_STATUS_TYPE_ERROR"; | ||
case CURAND_STATUS_OUT_OF_RANGE: | ||
return "CURAND_STATUS_OUT_OF_RANGE"; | ||
case CURAND_STATUS_LENGTH_NOT_MULTIPLE: | ||
return "CURAND_STATUS_LENGTH_NOT_MULTIPLE"; | ||
case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED: | ||
return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED"; | ||
case CURAND_STATUS_LAUNCH_FAILURE: | ||
return "CURAND_STATUS_LAUNCH_FAILURE"; | ||
case CURAND_STATUS_PREEXISTING_FAILURE: | ||
return "CURAND_STATUS_PREEXISTING_FAILURE"; | ||
case CURAND_STATUS_INITIALIZATION_FAILED: | ||
return "CURAND_STATUS_INITIALIZATION_FAILED"; | ||
case CURAND_STATUS_ARCH_MISMATCH: | ||
return "CURAND_STATUS_ARCH_MISMATCH"; | ||
case CURAND_STATUS_INTERNAL_ERROR: | ||
return "CURAND_STATUS_INTERNAL_ERROR"; | ||
} | ||
return "Unknown cuRAND status"; | ||
} | ||
|
||
} // namespace cuda | ||
} // namespace common | ||
|
||
/*! | ||
* \brief Check CUDA error. | ||
* \param msg Message to print if an error occured. | ||
*/ | ||
#define CHECK_CUDA_ERROR(msg) \ | ||
{ \ | ||
cudaError_t e = cudaGetLastError(); \ | ||
CHECK_EQ(e, cudaSuccess) << (msg) << " CUDA: " << cudaGetErrorString(e); \ | ||
} | ||
|
||
/*! | ||
* \brief Protected CUDA call. | ||
* \param func Expression to call. | ||
* | ||
* It checks for CUDA errors after invocation of the expression. | ||
*/ | ||
#define CUDA_CALL(func) \ | ||
{ \ | ||
cudaError_t e = (func); \ | ||
CHECK_EQ(e, cudaSuccess) << "CUDA: " << cudaGetErrorString(e); \ | ||
} | ||
|
||
/*! | ||
* \brief Protected cuBLAS call. | ||
* \param func Expression to call. | ||
* | ||
* It checks for cuBLAS errors after invocation of the expression. | ||
*/ | ||
#define CUBLAS_CALL(func) \ | ||
{ \ | ||
cublasStatus_t e = (func); \ | ||
CHECK_EQ(e, CUBLAS_STATUS_SUCCESS) \ | ||
<< "cuBLAS: " << common::cuda::CublasGetErrorString(e); \ | ||
} | ||
|
||
/*! | ||
* \brief Protected cuRAND call. | ||
* \param func Expression to call. | ||
* | ||
* It checks for cuRAND errors after invocation of the expression. | ||
*/ | ||
#define CURAND_CALL(func) \ | ||
{ \ | ||
curandStatus_t e = (func); \ | ||
CHECK_EQ(e, CURAND_STATUS_SUCCESS) \ | ||
<< "cuRAND: " << common::cuda::CurandGetErrorString(e); \ | ||
} | ||
|
||
#endif // MXNET_USE_CUDA | ||
|
||
#if MXNET_USE_CUDNN | ||
|
||
#include <cudnn.h> | ||
|
||
#define CUDNN_CALL(func) \ | ||
{ \ | ||
cudnnStatus_t e = (func); \ | ||
CHECK_EQ(e, CUDNN_STATUS_SUCCESS) << "cuDNN: " << cudnnGetErrorString(e); \ | ||
} | ||
|
||
#endif // MXNET_USE_CUDNN | ||
|
||
#endif // MXNET_COMMON_CUDA_UTILS_H_ |
Oops, something went wrong.