
Commit

Merge pull request #17 from dmlc/storage
Storage
hotpxl committed Aug 16, 2015
2 parents df28352 + 3a8e071 commit 9f4d31c
Showing 16 changed files with 730 additions and 65 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -49,3 +49,7 @@ Debug
*.swp
*.swo
*.swn

# Emacs
.clang_complete
.dir-locals.el
7 changes: 4 additions & 3 deletions Makefile
@@ -55,10 +55,10 @@ ifneq ($(ADD_LDFLAGS), NONE)
endif

#BIN = test/test_threaded_engine test/api_registry_test
BIN = test/api_registry_test
OBJ = storage.o narray_op_cpu.o
BIN = test/api_registry_test test/test_storage
OBJ = narray_op_cpu.o
# add threaded engine after it is done
OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o fully_connected_cpu.o static_graph.o
OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o storage.o fully_connected_cpu.o static_graph.o
CUOBJ =
SLIB = lib/libmxnet.so
ALIB = lib/libmxnet.a
@@ -93,6 +93,7 @@ lib/libmxnet.a: $(OBJ) $(OBJCXX11) $(CUOBJ)
lib/libmxnet.so: $(OBJ) $(OBJCXX11) $(CUOBJ)

test/api_registry_test: test/api_registry_test.cc lib/libmxnet.a
test/test_storage: test/test_storage.cc lib/libmxnet.a
#test/test_threaded_engine: test/test_threaded_engine.cc api/libmxnet.a

$(BIN) :
4 changes: 2 additions & 2 deletions doc/Doxyfile
@@ -753,7 +753,7 @@ WARN_LOGFILE =
# spaces.
# Note: If this tag is empty the current directory is searched.

INPUT = include
INPUT = include src/common

# This tag can be used to specify the character encoding of the source files
# that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses
@@ -1974,7 +1974,7 @@ INCLUDE_FILE_PATTERNS =
# recursively expanded use the := operator instead of the = operator.
# This tag requires that the tag ENABLE_PREPROCESSING is set to YES.

PREDEFINED =
PREDEFINED = MXNET_USE_CUDA DMLC_USE_CXX11

# If the MACRO_EXPANSION and EXPAND_ONLY_PREDEF tags are set to YES then this
# tag can be used to specify a list of macro names that should be expanded. The
10 changes: 5 additions & 5 deletions include/mxnet/narray.h
@@ -128,14 +128,14 @@ class NArray {
/*! \brief the real data chunk that backs NArray */
struct Chunk {
/*! \brief storage handle from storage engine */
StorageManager::Handle shandle;
Storage::Handle shandle;
/*! \brief variable from DAG engine */
DAGEngine::Variable var;
/*! \brief holds the data content */
TBlob data;
/*!
* \brief if this is true, this means the data do not come
* from StorageManager, and do not need to be freed
* from Storage, and do not need to be freed
*/
bool static_data;
/*! \brief whether allocation is delayed */
@@ -163,7 +163,7 @@
/*! \brief check if delay alloc is on, do alloc if not yet done */
inline void CheckAndAlloc(void) {
if (delay_alloc) {
shandle = StorageManager::Get()->Alloc(data.shape_.Size() * sizeof(real_t), shandle.ctx);
shandle = Storage::Get()->Alloc(data.shape_.Size() * sizeof(real_t), shandle.ctx);
data = TBlob(static_cast<real_t*>(shandle.dptr), data.shape_, shandle.ctx.dev_mask);
delay_alloc = false;
}
@@ -174,9 +174,9 @@
DAGEngine::Get()->PushDelete([](RunContext s) {}, shandle.ctx, var);
} else {
CHECK(!delay_alloc) << "deleted before allocation";
StorageManager::Handle h = this->shandle;
Storage::Handle h = this->shandle;
DAGEngine::Get()->PushDelete([h](RunContext s) {
StorageManager::Get()->Free(h);
Storage::Get()->Free(h);
}, shandle.ctx, var);
}
}
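The narray.h hunk keeps NArray's existing lifecycle and only retargets it from StorageManager to Storage: allocation is delayed until CheckAndAlloc runs, and on destruction the handle is copied by value into a callback pushed to the DAG engine, so the memory is released only after pending work drains. A minimal self-contained sketch of that capture-by-value pattern (not MXNet code; FakeEngine and the malloc/free backing are stand-ins for illustration):

#include <cstdlib>
#include <functional>
#include <utility>
#include <vector>

struct Handle { void* dptr = nullptr; };

// Stand-in for a DAG engine that runs deletion callbacks at a later point.
struct FakeEngine {
  std::vector<std::function<void()>> pending;
  void PushDelete(std::function<void()> fn) { pending.push_back(std::move(fn)); }
  void Drain() { for (auto& fn : pending) fn(); pending.clear(); }
};

struct Chunk {
  Handle shandle;
  FakeEngine* engine = nullptr;
  ~Chunk() {
    Handle h = shandle;  // copy the handle: `this` is about to be destroyed
    if (engine) engine->PushDelete([h]() { std::free(h.dptr); });
  }
};

int main() {
  FakeEngine engine;
  {
    Chunk c;
    c.engine = &engine;
    c.shandle.dptr = std::malloc(64);
  }  // ~Chunk runs here, but the buffer is freed only when the engine drains
  engine.Drain();
  return 0;
}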
6 changes: 6 additions & 0 deletions include/mxnet/operator.h
@@ -250,6 +250,8 @@ class OperatorProperty {
* return {{out_data[0], in_data[0]}};
* }
* \endcode
* \param in_data The input data in forward pass.
* \param out_data The output data in forward pass.
* \return list of pair of integers taken from the inputs vector,
* indicating possible in place operations.
*/
@@ -273,6 +275,10 @@ class OperatorProperty {
* return {{in_grad[0], in_data[0]}};
* }
* \endcode
* \param in_data The input data in forward pass.
* \param out_data The output data in forward pass.
* \param in_grad Gradient of inputs in backward pass.
* \param out_grad Gradient of outputs in backward pass.
* \return list of pair of integers taken from the inputs vector,
* indicating possible in place operations.
*/
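Both in-place options documented above return integer pairs drawn from the argument vectors. A hedged sketch of the forward case as a free function, mirroring the \code example in the comment; the exact member-function signature inside OperatorProperty is not shown in this hunk, so treat the prototype below as an assumption:

#include <utility>
#include <vector>

// Free-function restatement of the documented example: the returned pair
// says out_data[0] may reuse the memory that backs in_data[0].
std::vector<std::pair<int, int>> ForwardInplaceOption(
    const std::vector<int>& in_data,
    const std::vector<int>& out_data) {
  return {{out_data[0], in_data[0]}};
}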
75 changes: 50 additions & 25 deletions include/mxnet/storage.h
@@ -1,45 +1,70 @@
/*!
* Copyright (c) 2015 by Contributors
* Copyright (c) 2015 by Contributors
* \file storage.h
* \brief the memory allocator that manages the memory across multiple devices
* \brief Storage manager across multiple devices.
*/
#ifndef MXNET_STORAGE_H_
#define MXNET_STORAGE_H_
#include "./base.h"
#include "./context.h"

#include <memory>
#include "base.h"
#include "context.h"

namespace mxnet {
/*! \brief memory allocator of storage */
class StorageManager {

/*!
* \brief Storage manager across multiple devices.
*/
class Storage {
public:
/*!
* \brief storage handle the represents storage information
* \brief Storage handle.
*/
struct Handle {
/*! \brief pointer to the data */
void *dptr;
/*! \brief context information about device and deviceID */
Context ctx;
/*!
* \brief internal handle reserved for manager,
* user should not change or use this
* \brief Pointer to the data.
*/
void *handle_;
void* dptr;
/*!
* \brief Size of the storage.
*/
size_t size;
/*!
* \brief Context information about device and ID.
*/
Context ctx;
};
/*!
* \brief allocate a new contiguous memory for a given size
* \param size the total size of memory in bytes
* \param ctx context information about the device and deviceID
* \return Handle struct
* \brief Allocate a new contiguous memory for a given size.
* \param size Total size of memory in bytes.
* \param ctx Context information about the device and ID.
* \return Handle struct.
*/
virtual Handle Alloc(size_t size, Context ctx) = 0;
Handle Alloc(size_t size, Context ctx);
/*!
* \brief free the space represened the handle
* \param handle the handle to memory to be freed
* \brief Free storage.
* \param handle Handle struect.
*/
virtual void Free(Handle handle) = 0;
/*! \return storage manager singleton */
static StorageManager *Get();
}; // class StorageManager
void Free(Handle handle);
/*!
* \brief Destructor.
*/
~Storage();
/*!
* \return Storage singleton.
*/
static Storage* Get();

private:
/*!
* \brief Hidden constructors.
*/
Storage();
struct Impl;
std::unique_ptr<Impl> impl_;
DISALLOW_COPY_AND_ASSIGN(Storage);
}; // class Storage

} // namespace mxnet

#endif // MXNET_STORAGE_H_
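Taken together, the new header replaces the abstract StorageManager with a concrete pimpl-based Storage singleton whose handles also carry their size. A minimal usage sketch based only on the interface above (the default-constructed Context is assumed to denote the CPU device):

#include <mxnet/storage.h>

int main() {
  mxnet::Context ctx;  // assumed to default to the CPU device
  // Allocate 1 MB through the process-wide singleton.
  mxnet::Storage::Handle handle = mxnet::Storage::Get()->Alloc(1 << 20, ctx);
  // handle.dptr points at the allocation; handle.size and handle.ctx
  // record what was requested.
  mxnet::Storage::Get()->Free(handle);
  return 0;
}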
3 changes: 3 additions & 0 deletions src/common/concurrent_blocking_queue.h
@@ -13,6 +13,9 @@
#include <thread>
#include <cstdio>

/*!
* \brief Common components.
*/
namespace common {

/*!
157 changes: 157 additions & 0 deletions src/common/cuda_utils.h
@@ -0,0 +1,157 @@
/*!
* Copyright (c) 2015 by Contributors
* \file cuda_utils.h
* \brief CUDA debugging utilities.
*/
#ifndef MXNET_COMMON_CUDA_UTILS_H_
#define MXNET_COMMON_CUDA_UTILS_H_

#include <dmlc/logging.h>

#if MXNET_USE_CUDA

#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <curand.h>

namespace common {

/*!
* \brief CUDA utilities.
*/
namespace cuda {

/*!
* \brief Get string representation of cuBLAS errors.
* \param error The error.
* \return String representation.
*/
inline const char* CublasGetErrorString(cublasStatus_t error) {
switch (error) {
case CUBLAS_STATUS_SUCCESS:
return "CUBLAS_STATUS_SUCCESS";
case CUBLAS_STATUS_NOT_INITIALIZED:
return "CUBLAS_STATUS_NOT_INITIALIZED";
case CUBLAS_STATUS_ALLOC_FAILED:
return "CUBLAS_STATUS_ALLOC_FAILED";
case CUBLAS_STATUS_INVALID_VALUE:
return "CUBLAS_STATUS_INVALID_VALUE";
case CUBLAS_STATUS_ARCH_MISMATCH:
return "CUBLAS_STATUS_ARCH_MISMATCH";
case CUBLAS_STATUS_MAPPING_ERROR:
return "CUBLAS_STATUS_MAPPING_ERROR";
case CUBLAS_STATUS_EXECUTION_FAILED:
return "CUBLAS_STATUS_EXECUTION_FAILED";
case CUBLAS_STATUS_INTERNAL_ERROR:
return "CUBLAS_STATUS_INTERNAL_ERROR";
case CUBLAS_STATUS_NOT_SUPPORTED:
return "CUBLAS_STATUS_NOT_SUPPORTED";
default:
break;
}
return "Unknown cuBLAS status";
}

/*!
* \brief Get string representation of cuRAND errors.
* \param status The status.
* \return String representation.
*/
inline const char* CurandGetErrorString(curandStatus_t status) {
switch (status) {
case CURAND_STATUS_SUCCESS:
return "CURAND_STATUS_SUCCESS";
case CURAND_STATUS_VERSION_MISMATCH:
return "CURAND_STATUS_VERSION_MISMATCH";
case CURAND_STATUS_NOT_INITIALIZED:
return "CURAND_STATUS_NOT_INITIALIZED";
case CURAND_STATUS_ALLOCATION_FAILED:
return "CURAND_STATUS_ALLOCATION_FAILED";
case CURAND_STATUS_TYPE_ERROR:
return "CURAND_STATUS_TYPE_ERROR";
case CURAND_STATUS_OUT_OF_RANGE:
return "CURAND_STATUS_OUT_OF_RANGE";
case CURAND_STATUS_LENGTH_NOT_MULTIPLE:
return "CURAND_STATUS_LENGTH_NOT_MULTIPLE";
case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED:
return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED";
case CURAND_STATUS_LAUNCH_FAILURE:
return "CURAND_STATUS_LAUNCH_FAILURE";
case CURAND_STATUS_PREEXISTING_FAILURE:
return "CURAND_STATUS_PREEXISTING_FAILURE";
case CURAND_STATUS_INITIALIZATION_FAILED:
return "CURAND_STATUS_INITIALIZATION_FAILED";
case CURAND_STATUS_ARCH_MISMATCH:
return "CURAND_STATUS_ARCH_MISMATCH";
case CURAND_STATUS_INTERNAL_ERROR:
return "CURAND_STATUS_INTERNAL_ERROR";
}
return "Unknown cuRAND status";
}

} // namespace cuda
} // namespace common

/*!
* \brief Check CUDA error.
* \param msg Message to print if an error occurred.
*/
#define CHECK_CUDA_ERROR(msg) \
{ \
cudaError_t e = cudaGetLastError(); \
CHECK_EQ(e, cudaSuccess) << (msg) << " CUDA: " << cudaGetErrorString(e); \
}

/*!
* \brief Protected CUDA call.
* \param func Expression to call.
*
* It checks for CUDA errors after invocation of the expression.
*/
#define CUDA_CALL(func) \
{ \
cudaError_t e = (func); \
CHECK_EQ(e, cudaSuccess) << "CUDA: " << cudaGetErrorString(e); \
}

/*!
* \brief Protected cuBLAS call.
* \param func Expression to call.
*
* It checks for cuBLAS errors after invocation of the expression.
*/
#define CUBLAS_CALL(func) \
{ \
cublasStatus_t e = (func); \
CHECK_EQ(e, CUBLAS_STATUS_SUCCESS) \
<< "cuBLAS: " << common::cuda::CublasGetErrorString(e); \
}

/*!
* \brief Protected cuRAND call.
* \param func Expression to call.
*
* It checks for cuRAND errors after invocation of the expression.
*/
#define CURAND_CALL(func) \
{ \
curandStatus_t e = (func); \
CHECK_EQ(e, CURAND_STATUS_SUCCESS) \
<< "cuRAND: " << common::cuda::CurandGetErrorString(e); \
}

#endif // MXNET_USE_CUDA

#if MXNET_USE_CUDNN

#include <cudnn.h>

#define CUDNN_CALL(func) \
{ \
cudnnStatus_t e = (func); \
CHECK_EQ(e, CUDNN_STATUS_SUCCESS) << "cuDNN: " << cudnnGetErrorString(e); \
}

#endif // MXNET_USE_CUDNN

#endif // MXNET_COMMON_CUDA_UTILS_H_
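The macros in this new header wrap CUDA, cuBLAS, and cuRAND calls with CHECK_EQ-based error reporting. A short sketch of the intended usage, assuming the translation unit is compiled with MXNET_USE_CUDA defined and the include path below is adjusted to the actual layout:

#include <cstddef>
#include <cuda_runtime.h>
#include "common/cuda_utils.h"  // include path assumed for illustration

void AllocateAndRelease(size_t nbytes) {
  void* dptr = nullptr;
  // CUDA_CALL fails loudly with cudaGetErrorString output if the call errs.
  CUDA_CALL(cudaMalloc(&dptr, nbytes));
  CUDA_CALL(cudaFree(dptr));
  // CHECK_CUDA_ERROR reports any error left pending by earlier asynchronous
  // work, such as a kernel launch.
  CHECK_CUDA_ERROR("AllocateAndRelease");
}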