This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit: fix average pooling kernel size assignment error
modify white space and other format errors

remove wrap line whitespace format error

remove whitespace at the end of line183

change error message

add default pooling type to pool_enum::kMaxPooling

add pooling without kernel test cases

adjust pooling parameter order and add associated test points

remove wrong error test points

ignore kernel size check if global_pool is assigned to be true

modify whitespace

line length adjust

adjust linelength

finally learned to use cpplint

switch off all shape checks if global_pool is assigned

parse parameter when global_pool used

modify pooling shape inference logic

change a way to infer pooling shape

add push oshape

change kernel shape

prepare pooling parameter shapes

check lint

pooling parameters preparation

modify kernel shape computation method

modify a bit pooling_v1

more modification of pooling_v1

remove "avg pool"

tiny changes

change pooling args order back

use size_t instead of int

use changed order and only try tiny changes

try no kernel indicated to python interface with original order

useless modify for recommit
CoinCheung committed Mar 23, 2018
1 parent 09281c7 commit d6c0da8
Showing 4 changed files with 210 additions and 148 deletions.
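Taken together, the diffs below make the pooling kernel optional: kernel gains an empty default, pool_type defaults to max, and shape inference short-circuits when global_pool is set. A minimal usage sketch of the new behavior (assumes an MXNet build that includes this commit; the Python calls are the standard frontend, not part of the diff):

    # Sketch of the new behavior (assumes this commit is applied).
    import mxnet as mx

    data = mx.nd.ones((1, 3, 8, 8))

    # Global average pooling no longer needs an explicit kernel size.
    out = mx.nd.Pooling(data, global_pool=True, pool_type='avg')
    print(out.shape)  # (1, 3, 1, 1)

    # pool_type now defaults to 'max', so this is global max pooling.
    out = mx.nd.Pooling(data, global_pool=True)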
37 changes: 25 additions & 12 deletions src/operator/nn/pooling-inl.h
@@ -56,11 +56,11 @@ struct PoolingParam : public dmlc::Parameter<PoolingParam> {
     DMLC_DECLARE_FIELD(cudnn_off).set_default(false)
     .describe("Turn off cudnn pooling and use MXNet pooling operator. ");
 
-    DMLC_DECLARE_FIELD(kernel)
+    DMLC_DECLARE_FIELD(kernel).set_default(TShape())  // add default value here
     .enforce_nonzero()
     .describe("Pooling kernel size: (y, x) or (d, y, x)");
 
-    DMLC_DECLARE_FIELD(pool_type)
+    DMLC_DECLARE_FIELD(pool_type).set_default(pool_enum::kMaxPooling)  // add default pooling method
     .add_enum("max", pool_enum::kMaxPooling)
     .add_enum("avg", pool_enum::kAvgPooling)
     .add_enum("sum", pool_enum::kSumPooling)
@@ -132,19 +132,23 @@ class PoolingOp {
     using namespace mshadow;
     Stream<xpu> *s = ctx.get_stream<xpu>();
     const TShape& ishape = in_data.shape_;
+    TShape kernel = param_.kernel;
     TShape padding = param_.pad;
+    TShape stride = param_.stride;
     if (param_.global_pool) {
-      for (index_t i = 0; i < padding.ndim(); i++) {
+      kernel = TShape(ishape.data() + 2,
+                      ishape.data() + ishape.ndim());
+      padding = TShape(ishape.ndim() - 2);
+      for (index_t i = 0; i < ishape.ndim() - 2; i++) {
         padding[i] = 0;
       }
+      stride = TShape(ishape.ndim() - 2);
     }
 
     pool(s, in_data.dptr<DType>(), in_data.shape_, out_data.shape_,
-         param_.global_pool?
-         TShape(ishape.data()+ishape.ndim()-param_.kernel.ndim(), ishape.data()+ishape.ndim())
-         : param_.kernel,
+         kernel,
          padding,
-         param_.global_pool? TShape(param_.kernel.ndim()) : param_.stride,
+         stride,
          param_.pool_type, req, out_data.dptr<DType>());
   }

@@ -154,20 +158,24 @@ class PoolingOp {
     using namespace mshadow;
     Stream<xpu> *s = ctx.get_stream<xpu>();
     const TShape& ishape = in_data.shape_;
+    TShape kernel = param_.kernel;
     TShape padding = param_.pad;
+    TShape stride = param_.stride;
     if (param_.global_pool) {
-      for (index_t i = 0; i < padding.ndim(); i++) {
+      kernel = TShape(ishape.data() + 2,
+                      ishape.data() + ishape.ndim());
+      padding = TShape(ishape.ndim() - 2);
+      for (index_t i = 0; i < ishape.ndim() - 2; i++) {
         padding[i] = 0;
       }
+      stride = TShape(ishape.ndim() - 2);
     }
 
     unpool(s, out_grad.dptr<DType>(), in_data.dptr<DType>(), out_data.dptr<DType>(),
            in_grad.shape_, out_grad.shape_,
-           param_.global_pool?
-           TShape(ishape.data()+ishape.ndim()-param_.kernel.ndim(), ishape.data()+ishape.ndim())
-           : param_.kernel,
+           kernel,
            padding,
-           param_.global_pool? TShape(param_.kernel.ndim()) : param_.stride,
+           stride,
            param_.pool_type, req, in_grad.dptr<DType>());
   }

@@ -178,6 +186,11 @@ class PoolingOp {
 template<typename xpu, typename DType>
 PoolingOp<xpu, DType> &GetPoolingOp(const PoolingParam &param) {
   static thread_local PoolingOp<xpu, DType> op;
+  // check if filter size assigned correctly
+  if (param.global_pool == false) {
+    CHECK_GT(param.kernel.ndim(), 0U)
+        << "You need to set the kernel size if global pooling is not used";
+  }
   op.Init(param);
   return op;
 }
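With the CHECK_GT guard above, omitting the kernel is only legal together with global pooling. A small sketch of both sides of the check (assumes the standard Python frontend on a build with this commit; the error text is taken from the diff):

    import mxnet as mx

    data = mx.nd.ones((1, 3, 4, 4))

    # OK: no kernel needed once global_pool is set.
    mx.nd.Pooling(data, global_pool=True)

    # Trips the CHECK_GT above: no kernel and global_pool left False.
    try:
        mx.nd.Pooling(data)
    except mx.base.MXNetError as err:
        print(err)  # "You need to set the kernel size if global pooling is not used"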
155 changes: 76 additions & 79 deletions src/operator/nn/pooling.cc
@@ -46,15 +46,14 @@ static void PoolingParamParser(nnvm::NodeAttrs *attrs) {
     if (param.stride.ndim() == 0) param.stride = Shape2(1, 1);
     if (param.pad.ndim() == 0) param.pad = Shape2(0, 0);
   } else {
-    CHECK_EQ(param.kernel.ndim(), 3U) << param.kernel.ndim()
-                                      << "D pooling not supported";
+    // ignore kernel size only if global_pool not assigned false
+    if (param.global_pool == false) {
+      CHECK_EQ(param.kernel.ndim(), 3U) << param.kernel.ndim()
+                                        << "D pooling not supported";
+    }
     if (param.stride.ndim() == 0) param.stride = Shape3(1, 1, 1);
     if (param.pad.ndim() == 0) param.pad = Shape3(0, 0, 0);
   }
-  CHECK_EQ(param.stride.ndim(), param.kernel.ndim())
-      << "stride and kernel should have the same length";
-  CHECK_EQ(param.pad.ndim(), param.kernel.ndim())
-      << "pad and kernel should have the same length";
   attrs->parsed = std::move(param);
 }

@@ -98,28 +97,37 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs,
       << "Pooling: Input data should be 3D in (batch, channel, x)"
       << " Or 4D in (batch, channel, y, x) "
       << " Or 5D in (batch, channel, d, y, x)";
+  CHECK_LE(dshape.ndim(), 5U)
+      << "Pooling: Input data should be 3D in (batch, channel, x)"
+      << " Or 4D in (batch, channel, y, x) "
+      << " Or 5D in (batch, channel, d, y, x)";
   TShape oshape = dshape;
   if (dshape.ndim() == 0) return false;
-  if (param.kernel.ndim() == 1) {
+  if (param.global_pool) {
+    for (size_t i{2}; i < dshape.ndim(); i++)
+      oshape[i] = 1;
+    out_shape->clear();
+    out_shape->push_back(oshape);  // save output shape
+#if MXNET_USE_MKLDNN == 1
+    if (MKLDNNRequireWorkspace(param) && SupportMKLDNNPooling(param))
+      out_shape->push_back(oshape);  // for workspace
+#endif
+  } else if (param.kernel.ndim() == 1) {
     CHECK_EQ(dshape.ndim(), 3U)
         << "Pooling: Input data should be 3D in (batch, channel, x)";
-    if (param.global_pool) {
-      oshape[2] = 1;
-    } else {
-      CHECK(param.kernel[0] <= dshape[2] + 2 * param.pad[0])
-          << "kernel size (" << param.kernel[0] << ") exceeds input ("
-          << dshape[2] << " padded to " << (dshape[2] + 2 * param.pad[0])
-          << ")";
-      if (param.pooling_convention == pool_enum::kValid) {
-        oshape[2] = 1 +
-                    (dshape[2] + 2 * param.pad[0] - param.kernel[0]) /
-                        param.stride[0];
-      } else {
-        oshape[2] = 1 + static_cast<int>(ceil(
-                            static_cast<float>(dshape[2] + 2 * param.pad[0] -
-                                               param.kernel[0]) /
-                            param.stride[0]));
-      }
-    }
+    CHECK(param.kernel[0] <= dshape[2] + 2 * param.pad[0])
+        << "kernel size (" << param.kernel[0] << ") exceeds input ("
+        << dshape[2] << " padded to " << (dshape[2] + 2 * param.pad[0])
+        << ")";
+    if (param.pooling_convention == pool_enum::kValid) {
+      oshape[2] = 1 +
+                  (dshape[2] + 2 * param.pad[0] - param.kernel[0]) /
+                      param.stride[0];
+    } else {
+      oshape[2] = 1 + static_cast<int>(ceil(
+                          static_cast<float>(dshape[2] + 2 * param.pad[0] -
+                                             param.kernel[0]) /
+                          param.stride[0]));
+    }
     out_shape->clear();
     out_shape->push_back(oshape);  // save output shape
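Both conventions above are simple arithmetic on each spatial axis: 'valid' floors the quotient, 'full' ceils it. A worked example of the two formulas from the diff (standalone Python, independent of MXNet):

    import math

    def pool_out_dim(d, k, p, s, convention='valid'):
        """Output length along one spatial axis, per the formulas above."""
        if convention == 'valid':
            return 1 + (d + 2 * p - k) // s                   # floor
        return 1 + int(math.ceil(float(d + 2 * p - k) / s))   # 'full': ceil

    # d=6, kernel=3, pad=0, stride=2: the two conventions differ.
    print(pool_out_dim(6, 3, 0, 2, 'valid'))  # 2
    print(pool_out_dim(6, 3, 0, 2, 'full'))   # 3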
@@ -130,35 +138,30 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs,
   } else if (param.kernel.ndim() == 2) {
     CHECK_EQ(dshape.ndim(), 4U)
         << "Pooling: Input data should be 4D in (batch, channel, y, x)";
-    if (param.global_pool) {
-      oshape[2] = 1;
-      oshape[3] = 1;
-    } else {
-      CHECK(param.kernel[0] <= dshape[2] + 2 * param.pad[0])
-          << "kernel size (" << param.kernel[0] << ") exceeds input ("
-          << dshape[2] << " padded to " << (dshape[2] + 2 * param.pad[0])
-          << ")";
-      CHECK(param.kernel[1] <= dshape[3] + 2 * param.pad[1])
-          << "kernel size (" << param.kernel[1] << ") exceeds input ("
-          << dshape[3] << " padded to " << (dshape[3] + 2 * param.pad[1])
-          << ")";
-      if (param.pooling_convention == pool_enum::kValid) {
-        oshape[2] = 1 +
-                    (dshape[2] + 2 * param.pad[0] - param.kernel[0]) /
-                        param.stride[0];
-        oshape[3] = 1 +
-                    (dshape[3] + 2 * param.pad[1] - param.kernel[1]) /
-                        param.stride[1];
-      } else {
-        oshape[2] = 1 + static_cast<int>(ceil(
-                            static_cast<float>(dshape[2] + 2 * param.pad[0] -
-                                               param.kernel[0]) /
-                            param.stride[0]));
-        oshape[3] = 1 + static_cast<int>(ceil(
-                            static_cast<float>(dshape[3] + 2 * param.pad[1] -
-                                               param.kernel[1]) /
-                            param.stride[1]));
-      }
-    }
+    CHECK(param.kernel[0] <= dshape[2] + 2 * param.pad[0])
+        << "kernel size (" << param.kernel[0] << ") exceeds input ("
+        << dshape[2] << " padded to " << (dshape[2] + 2 * param.pad[0])
+        << ")";
+    CHECK(param.kernel[1] <= dshape[3] + 2 * param.pad[1])
+        << "kernel size (" << param.kernel[1] << ") exceeds input ("
+        << dshape[3] << " padded to " << (dshape[3] + 2 * param.pad[1])
+        << ")";
+    if (param.pooling_convention == pool_enum::kValid) {
+      oshape[2] = 1 +
+                  (dshape[2] + 2 * param.pad[0] - param.kernel[0]) /
+                      param.stride[0];
+      oshape[3] = 1 +
+                  (dshape[3] + 2 * param.pad[1] - param.kernel[1]) /
+                      param.stride[1];
+    } else {
+      oshape[2] = 1 + static_cast<int>(ceil(
+                          static_cast<float>(dshape[2] + 2 * param.pad[0] -
+                                             param.kernel[0]) /
+                          param.stride[0]));
+      oshape[3] = 1 + static_cast<int>(ceil(
+                          static_cast<float>(dshape[3] + 2 * param.pad[1] -
+                                             param.kernel[1]) /
+                          param.stride[1]));
+    }
     out_shape->clear();
     out_shape->push_back(oshape);  // save output shape
@@ -175,35 +178,29 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs,
         << "kernel size exceeds input";
     CHECK_LE(param.kernel[2], dshape[4] + 2 * param.pad[2])
         << "kernel size exceeds input";
-    if (param.global_pool) {
-      oshape[2] = 1;
-      oshape[3] = 1;
-      oshape[4] = 1;
-    } else {
-      if (param.pooling_convention == pool_enum::kValid) {
-        oshape[2] = 1 +
-                    (dshape[2] + 2 * param.pad[0] - param.kernel[0]) /
-                        param.stride[0];
-        oshape[3] = 1 +
-                    (dshape[3] + 2 * param.pad[1] - param.kernel[1]) /
-                        param.stride[1];
-        oshape[4] = 1 +
-                    (dshape[4] + 2 * param.pad[2] - param.kernel[2]) /
-                        param.stride[2];
-      } else {
-        oshape[2] = 1 + static_cast<int>(ceil(
-                            static_cast<float>(dshape[2] + 2 * param.pad[0] -
-                                               param.kernel[0]) /
-                            param.stride[0]));
-        oshape[3] = 1 + static_cast<int>(ceil(
-                            static_cast<float>(dshape[3] + 2 * param.pad[1] -
-                                               param.kernel[1]) /
-                            param.stride[1]));
-        oshape[4] = 1 + static_cast<int>(ceil(
-                            static_cast<float>(dshape[4] + 2 * param.pad[2] -
-                                               param.kernel[2]) /
-                            param.stride[2]));
-      }
-    }
+    if (param.pooling_convention == pool_enum::kValid) {
+      oshape[2] = 1 +
+                  (dshape[2] + 2 * param.pad[0] - param.kernel[0]) /
+                      param.stride[0];
+      oshape[3] = 1 +
+                  (dshape[3] + 2 * param.pad[1] - param.kernel[1]) /
+                      param.stride[1];
+      oshape[4] = 1 +
+                  (dshape[4] + 2 * param.pad[2] - param.kernel[2]) /
+                      param.stride[2];
+    } else {
+      oshape[2] = 1 + static_cast<int>(ceil(
+                          static_cast<float>(dshape[2] + 2 * param.pad[0] -
+                                             param.kernel[0]) /
+                          param.stride[0]));
+      oshape[3] = 1 + static_cast<int>(ceil(
+                          static_cast<float>(dshape[3] + 2 * param.pad[1] -
+                                             param.kernel[1]) /
+                          param.stride[1]));
+      oshape[4] = 1 + static_cast<int>(ceil(
+                          static_cast<float>(dshape[4] + 2 * param.pad[2] -
+                                             param.kernel[2]) /
+                          param.stride[2]));
+    }
 
     out_shape->clear();
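The diffs for the remaining two changed files, including the "pooling without kernel" test cases named in the commit log, are not rendered above. A plausible shape for such a test, as a hedged sketch only — the function name and tolerance are illustrative, not the committed code:

    import mxnet as mx
    import numpy as np

    def check_global_pool_without_kernel():
        data = mx.nd.array(np.random.uniform(size=(2, 3, 5, 7)))
        # Global pooling with no kernel argument...
        a = mx.nd.Pooling(data, global_pool=True, pool_type='avg')
        # ...should match an explicit kernel covering the full spatial extent.
        b = mx.nd.Pooling(data, kernel=(5, 7), pool_type='avg')
        np.testing.assert_allclose(a.asnumpy(), b.asnumpy(), rtol=1e-5)

    check_global_pool_without_kernel()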