Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[MXNET-33] Enhance mkldnn pooling to support full convention (#11047)
Browse files Browse the repository at this point in the history
* fix mkldnn pooling to support full convention

* backward with full convention

* fix

* add pooling test for full convention

* add function for computing padding size

* fix unit test

* only support max-pooling

* fix pooling bwd

* address review comment
  • Loading branch information
TaoLv authored and szha committed Nov 17, 2018
1 parent ac57ce3 commit dc3648b
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 21 deletions.
18 changes: 5 additions & 13 deletions src/operator/nn/mkldnn/mkldnn_pooling-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,20 +113,12 @@ inline bool SupportMKLDNNPooling(const PoolingParam &param,
if (!ret)
return false;

if (param.pooling_convention == pool_enum::kValid)
if (param.pooling_convention == pool_enum::kValid) {
return true;
else
return false;

// need to support pooling convention full
// https://issues.apache.org/jira/browse/MXNET-33
#if 0
if (((dshape[2] + 2 * param.pad[0] - param.kernel[0]) % param.stride[0] == 0) &&
((dshape[3] + 2 * param.pad[1] - param.kernel[1]) % param.stride[1] == 0))
return true;
else
return false;
#endif
} else {
// currently, only max-pooling is supported for full convention
return param.pool_type == pool_enum::kMaxPooling;
}
}

inline bool MKLDNNRequireWorkspace(const PoolingParam &param) {
Expand Down
40 changes: 32 additions & 8 deletions src/operator/nn/mkldnn/mkldnn_pooling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,14 @@ mkldnn::algorithm GetMKLDNNPoolAlgo(const PoolingParam &param) {
}
}

static inline int GetPaddingSizeFull(int x, int padl, int padr, int k, int s) {
if ((x + padl + padr - k) % s != 0) {
return (padr + s - ((x + padl + padr - k) % s));
} else {
return padr;
}
}

mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc(
const PoolingParam &param, const bool is_train, const memory::desc &data_md,
const memory::desc &out_md) {
Expand All @@ -154,11 +162,17 @@ mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc(
int pad_l_ = param.pad[1], pad_r_ = param.pad[1];
int stride_h_ = param.stride[0], stride_w_ = param.stride[1];

if (param.pooling_convention == pool_enum::kFull) {
pad_b_ = GetPaddingSizeFull(data_md.data.dims[2], pad_t_, pad_b_, kernel_h_, stride_h_);
pad_r_ = GetPaddingSizeFull(data_md.data.dims[3], pad_l_, pad_r_, kernel_w_, stride_w_);
}

const mkldnn::engine engine = CpuEngine::Get()->get_engine();
if (param.global_pool) {
pad_t_ = pad_b_ = pad_l_ = pad_r_ = 0;
stride_h_ = stride_w_ = 1;
}

if (pad_t_ != 0 || pad_l_ != 0) {
CHECK(param.pool_type == pool_enum::kAvgPooling ||
param.pool_type == pool_enum::kMaxPooling)
Expand All @@ -167,7 +181,6 @@ mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc(
CHECK_LT(pad_t_, kernel_h_);
}


const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param);
mkldnn::prop_kind kind = mkldnn::prop_kind::forward_scoring;
if (is_train && alg != algorithm::pooling_avg) {
Expand Down Expand Up @@ -227,17 +240,22 @@ MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam &param,
int pad_l_ = param.pad[1], pad_r_ = param.pad[1];
int stride_h_ = param.stride[0], stride_w_ = param.stride[1];

if (param.pooling_convention == pool_enum::kFull) {
pad_b_ = GetPaddingSizeFull(data_md.data.dims[2], pad_t_, pad_b_, kernel_h_, stride_h_);
pad_r_ = GetPaddingSizeFull(data_md.data.dims[3], pad_l_, pad_r_, kernel_w_, stride_w_);
}

if (param.global_pool) {
pad_t_ = pad_b_ = pad_l_ = pad_r_ = 0;
stride_h_ = stride_w_ = 1;
pad_t_ = pad_b_ = pad_l_ = pad_r_ = 0;
stride_h_ = stride_w_ = 1;
}

if (pad_t_ != 0 || pad_l_ != 0) {
CHECK(param.pool_type == pool_enum::kAvgPooling ||
param.pool_type == pool_enum::kMaxPooling)
<< "Padding implemented only for average and max pooling.";
CHECK_LT(pad_l_, kernel_w_);
CHECK_LT(pad_t_, kernel_h_);
CHECK(param.pool_type == pool_enum::kAvgPooling ||
param.pool_type == pool_enum::kMaxPooling)
<< "Padding implemented only for average and max pooling.";
CHECK_LT(pad_l_, kernel_w_);
CHECK_LT(pad_t_, kernel_h_);
}

const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param);
Expand Down Expand Up @@ -353,6 +371,12 @@ MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam &param,
int pad_t_ = param.pad[0], pad_b_ = param.pad[0];
int pad_l_ = param.pad[1], pad_r_ = param.pad[1];
int stride_h_ = param.stride[0], stride_w_ = param.stride[1];

if (param.pooling_convention == pool_enum::kFull) {
pad_b_ = GetPaddingSizeFull(data_md.data.dims[2], pad_t_, pad_b_, kernel_h_, stride_h_);
pad_r_ = GetPaddingSizeFull(data_md.data.dims[3], pad_l_, pad_r_, kernel_w_, stride_w_);
}

if (param.global_pool) {
pad_t_ = pad_b_ = pad_l_ = pad_r_ = 0;
stride_h_ = stride_w_ = 1;
Expand Down
29 changes: 29 additions & 0 deletions tests/python/gpu/test_operator_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,6 +996,35 @@ def test_3d_pooling(pool_type, p_value=2, count_include_pad=True):
test_3d_pooling('lp', p_value=3)


@with_seed()
def test_pooling_full_2d():
def test_pooling_full_2d_type(pool_type):
data = (2, 2, 10, 10)
kernel = (4, 5)
pad = (1, 2)
stride = (3, 4)

convention = 'full'
ctx_list = []
sym_list = []

# o_h = ceil((10 + 1 + 1 - 4) / 3) + 1 = 4
# o_w = ceil((10 + 2 + 2 - 5) / 4) + 1 = 4
ctx_list.append({'ctx': mx.cpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
sym_list.append(mx.sym.Pooling(kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
pooling_convention=convention, global_pool=False, name='pool'))

ctx_list.append({'ctx': mx.gpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}})
sym_list.append(mx.sym.Pooling(kernel=kernel, pad=pad, stride=stride, pool_type=pool_type,
pooling_convention=convention, global_pool=False, name='pool'))

check_consistency(sym_list, ctx_list)

test_pooling_full_2d_type('max')
test_pooling_full_2d_type('avg')
test_pooling_full_2d_type('sum')


@with_seed()
def test_global_pooling():
def test_1d_pooling(pool_type, p_value=2):
Expand Down

0 comments on commit dc3648b

Please sign in to comment.