Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] Refactored useful checker macros into a separate header file #230

Merged
merged 4 commits into from
Feb 23, 2019
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
- PR #208: Issue ml-common-3: Math.h: swap thrust::for_each with binaryOp,unaryOp
- PR #209: Simplify README.md, move build instructions to BUILD.md
- PR #225: Support for generating random integers
- PR #230: Refactored the cuda_utils header

## Bug Fixes

Expand Down
3 changes: 3 additions & 0 deletions cuML/src/dbscan/labelling/pack.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
*/

#pragma once

#include "utils.h"

namespace Dbscan {
namespace Label {

Expand Down
129 changes: 1 addition & 128 deletions ml-prims/src/cuda_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,41 +17,10 @@
#pragma once

#include <stdint.h>
#include <cstdio>
#include <stdexcept>
#include <string>
#include <iostream>
#include "utils.h"

namespace MLCommon {

/** macro to throw a c++ std::runtime_error */
#define THROW(fmt, ...) \
do { \
std::string msg; \
char errMsg[2048]; \
std::sprintf(errMsg, "Exception occured! file=%s line=%d: ", __FILE__, \
__LINE__); \
msg += errMsg; \
std::sprintf(errMsg, fmt, ##__VA_ARGS__); \
msg += errMsg; \
throw std::runtime_error(msg); \
} while (0)

/** macro to check for a conditional and assert on failure */
#define ASSERT(check, fmt, ...) \
do { \
if (!(check)) \
THROW(fmt, ##__VA_ARGS__); \
} while (0)

/** check for cuda runtime API errors and assert accordingly */
#define CUDA_CHECK(call) \
do { \
cudaError_t status = call; \
ASSERT(status == cudaSuccess, "FAIL: call='%s'. Reason:%s\n", #call, \
cudaGetErrorString(status)); \
} while (0)

/** helper macro for device inlined functions */
#define DI inline __device__
#define HDI inline __host__ __device__
Expand Down Expand Up @@ -102,60 +71,6 @@ constexpr HDI IntType log2(IntType num, IntType ret = IntType(0)) {
return num <= IntType(1) ? ret : log2(num >> IntType(1), ++ret);
}

/** cuda malloc */
template <typename Type>
void allocate(Type *&ptr, size_t len, bool setZero = false) {
CUDA_CHECK(cudaMalloc((void **)&ptr, sizeof(Type) * len));
if (setZero)
CUDA_CHECK(cudaMemset(ptr, 0, sizeof(Type) * len));
}

/** Helper function to calculate need memory for allocate to store dense matrix.
* @param rows number of rows in matrix
* @param columns number of columns in matrix
* @return need number of items to allocate via allocate()
* @sa allocate()
*/
inline size_t allocLengthForMatrix(size_t rows, size_t columns) {
return rows * columns;
}

/** performs a host to device copy */
template <typename Type>
void updateDevice(Type *dPtr, const Type *hPtr, size_t len,
cudaStream_t stream = 0) {
CUDA_CHECK(
cudaMemcpy(dPtr, hPtr, len * sizeof(Type), cudaMemcpyHostToDevice));
}

template <typename Type>
void updateDeviceAsync(Type *dPtr, const Type *hPtr, size_t len,
cudaStream_t stream) {
CUDA_CHECK(cudaMemcpyAsync(dPtr, hPtr, len * sizeof(Type),
cudaMemcpyHostToDevice, stream));
}

/** performs a device to host copy */
template <typename Type>
void updateHost(Type *hPtr, const Type *dPtr, size_t len,
cudaStream_t stream = 0) {
CUDA_CHECK(
cudaMemcpy(hPtr, dPtr, len * sizeof(Type), cudaMemcpyDeviceToHost));
}

template <typename Type>
void updateHostAsync(Type *hPtr, const Type *dPtr, size_t len,
cudaStream_t stream) {
CUDA_CHECK(cudaMemcpyAsync(hPtr, dPtr, len * sizeof(Type),
cudaMemcpyDeviceToHost, stream));
}

template <typename Type>
void copy(Type* dPtr1, const Type* dPtr2, size_t len) {
CUDA_CHECK(cudaMemcpy(dPtr1, dPtr2, len*sizeof(Type),
cudaMemcpyDeviceToDevice));
}

/** Device function to apply the input lambda across threads in the grid */
template <int ItemsPerThread, typename L>
DI void forEach(int num, L lambda) {
Expand Down Expand Up @@ -478,46 +393,4 @@ DI T shfl_xor(T val, int laneMask, int width = WarpSize,
#endif
}


/**
* @defgroup Debug utils for debug device code
* @{
*/
template<class T, class OutStream>
void myPrintHostVector(const char * variableName, const T * hostMem, size_t componentsCount, OutStream& out)
{
out << variableName << "=[";
for (size_t i = 0; i < componentsCount; ++i)
{
if (i != 0)
out << ",";
out << hostMem[i];
}
out << "];\n";
}

template<class T>
void myPrintHostVector(const char * variableName, const T * hostMem, size_t componentsCount)
{
myPrintHostVector(variableName, hostMem, componentsCount, std::cout);
std::cout.flush();
}

template<class T, class OutStream>
void myPrintDevVector(const char * variableName, const T * devMem, size_t componentsCount, OutStream& out)
{
T* hostMem = new T[componentsCount];
CUDA_CHECK(cudaMemcpy(hostMem, devMem, componentsCount * sizeof(T), cudaMemcpyDeviceToHost));
myPrintHostVector(variableName, hostMem, componentsCount, out);
delete []hostMem;
}

template<class T>
void myPrintDevVector(const char * variableName, const T * devMem, size_t componentsCount)
{
myPrintDevVector(variableName, devMem, componentsCount, std::cout);
std::cout.flush();
}
/** @} */

} // namespace MLCommon
212 changes: 212 additions & 0 deletions ml-prims/src/utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <execinfo.h>

#include <cuda_runtime.h>

#include <cstdio>
#include <iostream>
#include <limits>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

namespace MLCommon {

/** base exception class for the cuML or ml-prims project */
class Exception : public std::exception {
 public:
  /** default ctor */
  Exception() noexcept : std::exception(), msg() {}

  /** copy ctor: copies the source message verbatim.
   *  NOTE: deliberately does NOT re-collect the call stack — the message
   *  already carries the backtrace gathered where the source exception was
   *  constructed; re-collecting here would append a second trace (of the
   *  copy site, not the throw site) on every copy, growing the message. */
  Exception(const Exception& src) noexcept : std::exception(), msg(src.msg) {}

  /** ctor from an input message; appends the current call stack for debug */
  Exception(const std::string& _msg) noexcept : std::exception(), msg(_msg) {
    collectCallStack();
  }

  /** dtor */
  virtual ~Exception() noexcept {}

  /** get the message associated with this exception */
  virtual const char* what() const noexcept { return msg.c_str(); }

 private:
  /** message associated with this exception */
  std::string msg;

  /** append call stack info to this exception's message for ease of debug.
   *  Only available under GCC-compatible compilers; a no-op elsewhere. */
  // Courtesy: https://www.gnu.org/software/libc/manual/html_node/Backtraces.html
  void collectCallStack() noexcept {
#ifdef __GNUC__
    const int MaxStackDepth = 64;
    void* stack[MaxStackDepth];
    auto depth = backtrace(stack, MaxStackDepth);
    std::ostringstream oss;
    oss << std::endl << "Obtained " << depth << " stack frames" << std::endl;
    char** strings = backtrace_symbols(stack, depth);
    if (strings == nullptr) {
      oss << "But no stack trace could be found!" << std::endl;
      msg += oss.str();
      return;
    }
    ///@todo: support for demangling of C++ symbol names
    for (int i = 0; i < depth; ++i) {
      oss << "#" << i << " in " << strings[i] << std::endl;
    }
    free(strings);  // backtrace_symbols mallocs the whole array as one block
    msg += oss.str();
#endif  // __GNUC__
  }
};

/** macro to throw a runtime error (MLCommon::Exception) with file/line info.
 *  Uses std::snprintf (bounded) so an oversized formatted message cannot
 *  overflow the 2048-byte scratch buffer; overlong text is truncated. */
#define THROW(fmt, ...)                                                   \
  do {                                                                    \
    std::string msg;                                                      \
    char errMsg[2048];                                                    \
    std::snprintf(errMsg, sizeof(errMsg),                                 \
                  "Exception occurred! file=%s line=%d: ", __FILE__,      \
                  __LINE__);                                              \
    msg += errMsg;                                                        \
    std::snprintf(errMsg, sizeof(errMsg), fmt, ##__VA_ARGS__);            \
    msg += errMsg;                                                        \
    throw MLCommon::Exception(msg);                                       \
  } while (0)

/** macro to check for a conditional and assert (throw via THROW) on failure */
#define ASSERT(check, fmt, ...)        \
  do {                                 \
    if (!(check))                      \
      THROW(fmt, ##__VA_ARGS__);       \
  } while (0)

/** check for cuda runtime API errors and assert accordingly.
 *  Evaluates `call` exactly once. */
#define CUDA_CHECK(call)                                                  \
  do {                                                                    \
    cudaError_t status = call;                                            \
    ASSERT(status == cudaSuccess, "FAIL: call='%s'. Reason:%s\n", #call,  \
           cudaGetErrorString(status));                                   \
  } while (0)


/**
* @defgroup Copy Copy methods
* @{
*/
/** performs a host to device copy, synchronous with respect to the host.
 *  The copy is now issued on @p stream (previously the parameter was accepted
 *  but silently ignored) and the stream is synchronized before returning, so
 *  the data is guaranteed ready when this call completes. The default
 *  stream 0 preserves the old fully-synchronous behavior for existing
 *  callers. For an overlap-friendly copy use updateDeviceAsync instead. */
template <typename Type>
void updateDevice(Type *dPtr, const Type *hPtr, size_t len,
                  cudaStream_t stream = 0) {
  CUDA_CHECK(cudaMemcpyAsync(dPtr, hPtr, len * sizeof(Type),
                             cudaMemcpyHostToDevice, stream));
  // block until the copy has finished so callers keep sync semantics
  CUDA_CHECK(cudaStreamSynchronize(stream));
}

/** performs an async host to device copy on the given stream.
 *  @note hPtr should point to pinned host memory for a truly asynchronous
 *  transfer — TODO confirm callers' allocation strategy. */
template <typename Type>
void updateDeviceAsync(Type *dPtr, const Type *hPtr, size_t len,
                       cudaStream_t stream) {
  const size_t nbytes = sizeof(Type) * len;
  CUDA_CHECK(
    cudaMemcpyAsync(dPtr, hPtr, nbytes, cudaMemcpyHostToDevice, stream));
}

/** performs a device to host copy, synchronous with respect to the host.
 *  The copy is now issued on @p stream (previously the parameter was accepted
 *  but silently ignored) and the stream is synchronized before returning, so
 *  the host buffer is guaranteed filled when this call completes. The
 *  default stream 0 preserves the old fully-synchronous behavior. For an
 *  overlap-friendly copy use updateHostAsync instead. */
template <typename Type>
void updateHost(Type *hPtr, const Type *dPtr, size_t len,
                cudaStream_t stream = 0) {
  CUDA_CHECK(cudaMemcpyAsync(hPtr, dPtr, len * sizeof(Type),
                             cudaMemcpyDeviceToHost, stream));
  // block until the copy has finished so callers keep sync semantics
  CUDA_CHECK(cudaStreamSynchronize(stream));
}

/** performs an async device to host copy on the given stream.
 *  @note hPtr should point to pinned host memory for a truly asynchronous
 *  transfer — TODO confirm callers' allocation strategy. */
template <typename Type>
void updateHostAsync(Type *hPtr, const Type *dPtr, size_t len,
                     cudaStream_t stream) {
  const size_t nbytes = sizeof(Type) * len;
  CUDA_CHECK(
    cudaMemcpyAsync(hPtr, dPtr, nbytes, cudaMemcpyDeviceToHost, stream));
}

/** performs a device to device copy of len items of Type */
template <typename Type>
void copy(Type* dPtr1, const Type* dPtr2, size_t len) {
  const size_t nbytes = len * sizeof(Type);
  CUDA_CHECK(cudaMemcpy(dPtr1, dPtr2, nbytes, cudaMemcpyDeviceToDevice));
}
/** @} */

/** Helper function to calculate the number of items needed to store a dense
 * matrix via allocate().
 * @param rows number of rows in the matrix
 * @param columns number of columns in the matrix
 * @return number of items to allocate via allocate()
 * @throws std::overflow_error if rows*columns does not fit in size_t
 *         (previously the product silently wrapped around)
 * @sa allocate()
 */
inline size_t allocLengthForMatrix(size_t rows, size_t columns) {
  // guard against silent wrap-around for pathologically large matrices
  if (columns != 0 && rows > std::numeric_limits<size_t>::max() / columns) {
    throw std::overflow_error(
      "allocLengthForMatrix: rows*columns overflows size_t");
  }
  return rows * columns;
}

/** cuda malloc: allocates len items of Type on the device, optionally
 *  zero-initialized via cudaMemset */
template <typename Type>
void allocate(Type *&ptr, size_t len, bool setZero = false) {
  const size_t nbytes = sizeof(Type) * len;
  CUDA_CHECK(cudaMalloc((void **)&ptr, nbytes));
  if (setZero) {
    CUDA_CHECK(cudaMemset(ptr, 0, nbytes));
  }
}

/**
 * @defgroup Debug utils for debug device code
 * @{
 */
/** streams `variableName=[v0,v1,...];\n` for a host-side array to `out` */
template <class T, class OutStream>
void myPrintHostVector(const char* variableName, const T* hostMem,
                       size_t componentsCount, OutStream& out) {
  out << variableName << "=[";
  const char* separator = "";  // empty before the first element, "," after
  for (size_t idx = 0; idx < componentsCount; ++idx) {
    out << separator << hostMem[idx];
    separator = ",";
  }
  out << "];\n";
}

/** prints a host-side array to stdout and flushes immediately */
template <class T>
void myPrintHostVector(const char* variableName, const T* hostMem,
                       size_t componentsCount) {
  myPrintHostVector(variableName, hostMem, componentsCount, std::cout);
  std::cout.flush();  // flush right away so debug output survives a crash
}

/** copies a device-side array to the host and streams it via
 *  myPrintHostVector.
 *  Uses std::vector for the host staging buffer so the memory is released
 *  even when CUDA_CHECK throws — the original raw new[]/delete[] pair leaked
 *  the buffer on a failed cudaMemcpy. */
template <class T, class OutStream>
void myPrintDevVector(const char* variableName, const T* devMem,
                      size_t componentsCount, OutStream& out) {
  std::vector<T> hostMem(componentsCount);
  CUDA_CHECK(cudaMemcpy(hostMem.data(), devMem, componentsCount * sizeof(T),
                        cudaMemcpyDeviceToHost));
  myPrintHostVector(variableName, hostMem.data(), componentsCount, out);
}

/** prints a device-side array to stdout and flushes immediately */
template <class T>
void myPrintDevVector(const char* variableName, const T* devMem,
                      size_t componentsCount) {
  myPrintDevVector(variableName, devMem, componentsCount, std::cout);
  std::cout.flush();  // flush right away so debug output survives a crash
}
/** @} */

}; // end namespace MLCommon
2 changes: 1 addition & 1 deletion ml-prims/test/cuda_utils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace MLCommon {

TEST(Utils, Assert) {
ASSERT_NO_THROW(ASSERT(1 == 1, "Should not assert!"));
ASSERT_THROW(ASSERT(1 != 1, "Should assert!"), std::runtime_error);
ASSERT_THROW(ASSERT(1 != 1, "Should assert!"), Exception);
}

TEST(Utils, CudaCheck) { ASSERT_NO_THROW(CUDA_CHECK(cudaFree(nullptr))); }
Expand Down