Skip to content

Commit

Permalink
Add layout flags (apache#169)
Browse files Browse the repository at this point in the history
* Add layout flags

* fix
  • Loading branch information
piiswrong authored and tqchen committed Nov 3, 2016
1 parent 1728adf commit bb0f8c7
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 0 deletions.
92 changes: 92 additions & 0 deletions mshadow/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,10 @@ extern "C" {
} \
}

// When cuDNN (>= 5) is not compiled in, provide a stand-in definition of
// cudnnTensorFormat_t so the LayoutType traits below still compile; in that
// configuration their kCudnnFlag members hold the sentinel value -1.
#if !(MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 5)
typedef int cudnnTensorFormat_t;
#endif

#include "./half.h"
#include "./logging.h"
/*! \brief namespace for mshadow */
Expand All @@ -262,6 +266,7 @@ enum TypeFlag {

template<typename DType>
struct DataType;

template<>
struct DataType<float> {
static const int kFlag = kFloat32;
Expand Down Expand Up @@ -298,6 +303,63 @@ struct DataType<int32_t> {
/*! \brief type enum value for default real type */
const int default_type_flag = DataType<default_real_t>::kFlag;

/*!
 * \brief flags for tensor memory layouts.
 *  4-D layouts occupy the low values; 5-D layouts start at 1 << 5 so the
 *  two groups use disjoint value ranges.
 */
enum LayoutFlag {
  kNCHW = 0,  // 4-D: batch, channel, height, width
  kNHWC,      // 4-D: batch, height, width, channel
  kCHWN,      // 4-D: channel, height, width, batch

  kNCDHW = 1 << 5,  // 5-D: batch, channel, depth, height, width
  kNDHWC,           // 5-D: batch, depth, height, width, channel
  kCDHWN            // 5-D: channel, depth, height, width, batch
};

/*!
 * \brief compile-time traits of a layout flag: its dimensionality (kNdim)
 *  and the matching cuDNN tensor-format constant (kCudnnFlag).
 *  Specialized below for each supported LayoutFlag value.
 * \tparam layout a value from LayoutFlag
 */
template<int layout>
struct LayoutType;

/*! \brief traits for the 4-D NCHW layout */
template<>
struct LayoutType<kNCHW> {
  static const index_t kNdim = 4;
#if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 5)
  static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NCHW;
#else
  // cuDNN unavailable: cudnnTensorFormat_t is the int fallback typedef,
  // and -1 marks the flag as unusable.
  static const cudnnTensorFormat_t kCudnnFlag = -1;
#endif
};

/*! \brief traits for the 4-D NHWC layout */
template<>
struct LayoutType<kNHWC> {
  static const index_t kNdim = 4;
#if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 5)
  static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NHWC;
#else
  // cuDNN unavailable: -1 marks the flag as unusable.
  static const cudnnTensorFormat_t kCudnnFlag = -1;
#endif
};

/*! \brief layout flag used by default for 4-D tensors */
const int default_layout = kNCHW;

/*! \brief traits for the 5-D NCDHW layout */
template<>
struct LayoutType<kNCDHW> {
  static const index_t kNdim = 5;
#if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 5)
  // cuDNN has no dedicated 5-D format enum; NCDHW maps onto CUDNN_TENSOR_NCHW.
  static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NCHW;
#else
  // cuDNN unavailable: -1 marks the flag as unusable.
  static const cudnnTensorFormat_t kCudnnFlag = -1;
#endif
};

/*! \brief traits for the 5-D NDHWC layout */
template<>
struct LayoutType<kNDHWC> {
  static const index_t kNdim = 5;
#if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 5)
  // cuDNN has no dedicated 5-D format enum; NDHWC maps onto CUDNN_TENSOR_NHWC.
  static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NHWC;
#else
  // cuDNN unavailable: -1 marks the flag as unusable.
  static const cudnnTensorFormat_t kCudnnFlag = -1;
#endif
};

/*! \brief layout flag used by default for 5-D tensors */
const int default_layout_5d = kNCDHW;

/*! \brief namespace for operators */
namespace op {
// binary operator
Expand Down Expand Up @@ -604,6 +666,36 @@ struct minimum {
LOG(FATAL) << "Unknown type enum " << type; \
}

/*!
 * \brief Expand a runtime layout flag into a compile-time constant.
 *  For each supported layout, binds the value to `const int Layout` so the
 *  code in __VA_ARGS__ can use it as a template argument (e.g. LayoutType<Layout>).
 *  Aborts via LOG(FATAL) on an unrecognized flag.
 *
 *  Fix: the bound constants are now fully qualified (mshadow::kNCHW, ...)
 *  to match the case labels, so the macro also expands correctly when used
 *  outside namespace mshadow.
 */
#define MSHADOW_LAYOUT_SWITCH(layout, Layout, ...)  \
  switch (layout) {                                 \
  case mshadow::kNCHW:                              \
    {                                               \
      const int Layout = mshadow::kNCHW;            \
      {__VA_ARGS__}                                 \
    }                                               \
    break;                                          \
  case mshadow::kNHWC:                              \
    {                                               \
      const int Layout = mshadow::kNHWC;            \
      {__VA_ARGS__}                                 \
    }                                               \
    break;                                          \
  case mshadow::kNCDHW:                             \
    {                                               \
      const int Layout = mshadow::kNCDHW;           \
      {__VA_ARGS__}                                 \
    }                                               \
    break;                                          \
  case mshadow::kNDHWC:                             \
    {                                               \
      const int Layout = mshadow::kNDHWC;           \
      {__VA_ARGS__}                                 \
    }                                               \
    break;                                          \
  default:                                          \
    LOG(FATAL) << "Unknown layout enum " << layout; \
  }

/*! \brief get data type size from type enum */
inline size_t mshadow_sizeof(int type) {
int size = 0;
Expand Down
65 changes: 65 additions & 0 deletions mshadow/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,71 @@ MSHADOW_XINLINE Shape<5> Shape5(index_t s0, index_t s1, index_t s2,
s[0] = s0; s[1] = s1; s[2] = s2; s[3] = s3; s[4] = s4;
return s;
}

/*!
 * \brief Convert a 4-D shape from one layout to another.
 * \param src shape expressed in src_layout axis order
 * \param src_layout layout flag of src (kNCHW, kNHWC or kCHWN)
 * \param dst_layout layout flag of the returned shape
 * \return the same shape re-expressed in dst_layout axis order
 *  Aborts via LOG(FATAL) on an unsupported layout flag.
 */
inline Shape<4> ConvertLayout(const Shape<4>& src, int src_layout, int dst_layout) {
  // Step 1: normalize src into canonical NCHW order.
  Shape<4> nchw;
  switch (src_layout) {
    case kNCHW:
      nchw = src;
      break;
    case kNHWC:
      // src = (N, H, W, C)
      nchw[0] = src[0];
      nchw[1] = src[3];
      nchw[2] = src[1];
      nchw[3] = src[2];
      break;
    case kCHWN:
      // src = (C, H, W, N); added so every 4-D LayoutFlag is accepted.
      nchw[0] = src[3];
      nchw[1] = src[0];
      nchw[2] = src[1];
      nchw[3] = src[2];
      break;
    default:
      LOG(FATAL) << "Invalid layout for 4d shape " << src_layout;
  }
  // Step 2: permute the canonical NCHW shape into the destination layout.
  Shape<4> dst;
  switch (dst_layout) {
    case kNCHW:
      return nchw;
    case kNHWC:
      dst[0] = nchw[0];
      dst[1] = nchw[2];
      dst[2] = nchw[3];
      dst[3] = nchw[1];
      break;
    case kCHWN:
      dst[0] = nchw[1];
      dst[1] = nchw[2];
      dst[2] = nchw[3];
      dst[3] = nchw[0];
      break;
    default:
      // Fixed: previously logged src_layout here, misreporting the bad value.
      LOG(FATAL) << "Invalid layout for 4d shape " << dst_layout;
  }
  return dst;
}

/*!
 * \brief Convert a 5-D shape from one layout to another.
 * \param src shape expressed in src_layout axis order
 * \param src_layout layout flag of src (kNCDHW, kNDHWC or kCDHWN)
 * \param dst_layout layout flag of the returned shape
 * \return the same shape re-expressed in dst_layout axis order
 *  Aborts via LOG(FATAL) on an unsupported layout flag.
 */
inline Shape<5> ConvertLayout(const Shape<5>& src, int src_layout, int dst_layout) {
  // Step 1: normalize src into canonical NCDHW order.
  Shape<5> ncdhw;
  switch (src_layout) {
    case kNCDHW:
      ncdhw = src;
      break;
    case kNDHWC:
      // src = (N, D, H, W, C)
      ncdhw[0] = src[0];
      ncdhw[1] = src[4];
      ncdhw[2] = src[1];
      ncdhw[3] = src[2];
      ncdhw[4] = src[3];
      break;
    case kCDHWN:
      // src = (C, D, H, W, N); added so every 5-D LayoutFlag is accepted.
      ncdhw[0] = src[4];
      ncdhw[1] = src[0];
      ncdhw[2] = src[1];
      ncdhw[3] = src[2];
      ncdhw[4] = src[3];
      break;
    default:
      LOG(FATAL) << "Invalid layout for 5d shape " << src_layout;
  }
  // Step 2: permute the canonical NCDHW shape into the destination layout.
  Shape<5> dst;
  switch (dst_layout) {
    case kNCDHW:
      return ncdhw;
    case kNDHWC:
      dst[0] = ncdhw[0];
      dst[1] = ncdhw[2];
      dst[2] = ncdhw[3];
      dst[3] = ncdhw[4];
      dst[4] = ncdhw[1];
      break;
    case kCDHWN:
      dst[0] = ncdhw[1];
      dst[1] = ncdhw[2];
      dst[2] = ncdhw[3];
      dst[3] = ncdhw[4];
      dst[4] = ncdhw[0];
      break;
    default:
      // Fixed: previously logged src_layout here, misreporting the bad value.
      LOG(FATAL) << "Invalid layout for 5d shape " << dst_layout;
  }
  return dst;
}

/*!
* \brief computaion stream structure, used for asynchronize computation
*/
Expand Down

0 comments on commit bb0f8c7

Please sign in to comment.