[BUGFIX] Add checks in BatchNorm's infer shape (apache#20415)
bgawrych authored and chinakook committed Aug 1, 2021
1 parent 5f0501b commit 1724eb9
Showing 2 changed files with 24 additions and 24 deletions.
18 changes: 9 additions & 9 deletions src/operator/nn/batch_norm.cc
@@ -375,15 +375,15 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs,

const index_t channelCount = dshape[channelAxis];

- in_shape->at(batchnorm::kGamma) = mxnet::TShape(Shape1(channelCount));
- in_shape->at(batchnorm::kBeta) = mxnet::TShape(Shape1(channelCount));
- in_shape->at(batchnorm::kInMovingMean) = mxnet::TShape(Shape1(channelCount)); // kMovingMean
- in_shape->at(batchnorm::kInMovingVar) = mxnet::TShape(Shape1(channelCount)); // kMovingVar

- out_shape->clear();
- out_shape->push_back(dshape); // kOut
- out_shape->push_back(Shape1(channelCount)); // kMean
- out_shape->push_back(Shape1(channelCount)); // kVar
+ SHAPE_ASSIGN_CHECK(*in_shape, batchnorm::kGamma, Shape1(channelCount));
+ SHAPE_ASSIGN_CHECK(*in_shape, batchnorm::kBeta, Shape1(channelCount));
+ SHAPE_ASSIGN_CHECK(*in_shape, batchnorm::kInMovingMean, Shape1(channelCount)); // kMovingMean
+ SHAPE_ASSIGN_CHECK(*in_shape, batchnorm::kInMovingVar, Shape1(channelCount)); // kMovingVar


+ SHAPE_ASSIGN_CHECK(*out_shape, batchnorm::kOut, dshape);
+ SHAPE_ASSIGN_CHECK(*out_shape, batchnorm::kMean, Shape1(channelCount));
+ SHAPE_ASSIGN_CHECK(*out_shape, batchnorm::kVar, Shape1(channelCount));

return true;
}
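The gist of the change above: the old code unconditionally overwrote the gamma, beta, moving-mean and moving-var shapes, so a mismatching user-supplied shape went unnoticed, while SHAPE_ASSIGN_CHECK-style assignment fills in a still-unknown shape but reports an error when an already-known shape disagrees. A minimal standalone C++ sketch of that validate-or-assign idea (shape_assign_check is a hypothetical stand-in, not MXNet's actual macro):

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

using Shape = std::vector<int64_t>;  // empty == unknown, mirroring an unset TShape

// Hypothetical stand-in for SHAPE_ASSIGN_CHECK: assign if unknown, verify otherwise.
void shape_assign_check(std::vector<Shape>* shapes, size_t idx, const Shape& expected) {
  Shape& current = shapes->at(idx);
  if (current.empty()) {
    current = expected;                  // shape not inferred yet: fill it in
  } else if (current != expected) {      // shape already known: it must agree
    throw std::runtime_error("incompatible shape at input " + std::to_string(idx));
  }
}

int main() {
  const int64_t channel_count = 32;
  std::vector<Shape> in_shape(3);
  in_shape[1] = {channel_count};         // already known and consistent
  in_shape[2] = {channel_count + 1};     // moving stat with a wrong length

  shape_assign_check(&in_shape, 0, {channel_count});  // fills the unknown slot
  shape_assign_check(&in_shape, 1, {channel_count});  // passes: shapes agree
  try {
    shape_assign_check(&in_shape, 2, {channel_count});  // now reported instead of overwritten
  } catch (const std::exception& e) {
    std::cout << "caught: " << e.what() << "\n";
  }
  return 0;
}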
30 changes: 15 additions & 15 deletions src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -159,10 +159,10 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
if (param.axis != 1 || shape.ndim() != 4) {
// reshape to (N, C, 1, D)
mxnet::TShape new_shape{
- static_cast<dim_t>(shape.ProdShape(0, real_axis)),
+ static_cast<index_t>(shape.ProdShape(0, real_axis)),
shape[real_axis],
1,
- static_cast<dim_t>(shape.ProdShape(real_axis + 1,
+ static_cast<index_t>(shape.ProdShape(real_axis + 1,
static_cast<int>(shape.ndim())))
};
in_data[batchnorm::kData] = in_data[batchnorm::kData].Reshape(new_shape);
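For context on the reshape above: when the normalized axis is not 1 or the input is not 4-D, the tensor is collapsed to a canonical (N, C, 1, D) layout, with N the product of the dimensions before the channel axis and D the product of the dimensions after it. A rough standalone sketch of that collapse, using a plain std::vector in place of mxnet::TShape and a hypothetical prod_shape helper mirroring ProdShape:

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

using Shape = std::vector<int64_t>;

// Hypothetical helper mirroring TShape::ProdShape: product of dims in [begin, end).
int64_t prod_shape(const Shape& s, int begin, int end) {
  return std::accumulate(s.begin() + begin, s.begin() + end,
                         static_cast<int64_t>(1), std::multiplies<int64_t>());
}

// Collapse an arbitrary layout into (N, C, 1, D) around the channel axis.
Shape to_nc1d(const Shape& shape, int channel_axis) {
  const int ndim = static_cast<int>(shape.size());
  return Shape{prod_shape(shape, 0, channel_axis),          // N: everything before C
               shape[channel_axis],                         // C: the normalized axis
               1,
               prod_shape(shape, channel_axis + 1, ndim)};  // D: everything after C
}

int main() {
  Shape nhwc{8, 14, 14, 64};            // e.g. batch norm over axis=3
  Shape collapsed = to_nc1d(nhwc, 3);   // -> 1568 64 1 1
  for (int64_t d : collapsed) std::cout << d << " ";
  std::cout << "\n";
  return 0;
}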
@@ -195,7 +195,7 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
const mkldnn::memory &weight_mem = fwd.GetWeight();
float* weight_buf = reinterpret_cast<float *>(weight_mem.get_data_handle());

- nnvm::dim_t channels_ = data.shape()[1];
+ index_t channels_ = data.shape()[1];
CHECK(weight_mem.get_desc().get_size() == channels_ * sizeof(float) * 2);
float* weight_ptr = gamma.data().dptr<float>();
float* bias_ptr = beta.data().dptr<float>();
@@ -204,13 +204,13 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
memcpy(weight_buf, weight_ptr, copy_size);
memcpy(&weight_buf[channels_], bias_ptr, copy_size);
} else if (IsBNWriting(req[batchnorm::kGamma])) {
- for (int i = 0; i < channels_; i++) {
+ for (index_t i = 0; i < channels_; i++) {
weight_buf[i] = 1.0f;
weight_ptr[i] = 1.0f;
weight_buf[channels_ + i] = bias_ptr[i]; // bias
}
} else {
- for (int i = 0; i < channels_; i++) {
+ for (index_t i = 0; i < channels_; i++) {
weight_buf[i] = 1.0f;
weight_buf[channels_ + i] = bias_ptr[i]; // bias
}
@@ -237,7 +237,7 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
float* inmean = aux_states[batchnorm::kMovingMean].data().dptr<float>();
float* invar = aux_states[batchnorm::kMovingVar].data().dptr<float>();
// to align with origin implmentation: batch_norm.cc: L164
- for (int i = 0; i < channels_; i++) {
+ for (index_t i = 0; i < channels_; i++) {
omean[i] = inmean[i];
ovar[i] = VARIANCE_TO_INVSTD(invar[i], param.eps);
}
@@ -254,7 +254,7 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
MKLDNNStream::Get()->Submit();

float* ovar = outVar.data().dptr<float>();
- for (int i = 0; i < channels_; i++) {
+ for (index_t i = 0; i < channels_; i++) {
ovar[i] = VARIANCE_TO_INVSTD(ovar[i], param.eps);
}
}
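The loop above converts the plain variance produced by the oneDNN primitive into the inverse standard deviation that the rest of the batch-norm code stores, presumably via the usual relationship invstd = 1 / sqrt(var + eps); INVSTD_TO_VARIANCE in the backward hunks below would then be the inverse mapping var = 1 / invstd^2 - eps. A small sketch of that round trip, with the formulas written out as assumptions rather than copied from the header:

#include <cassert>
#include <cmath>
#include <cstdio>

// Assumed definitions matching the names used in the diff; the real macros
// live in mkldnn_batch_norm-inl.h and may differ in detail.
inline float variance_to_invstd(float var, float eps) {
  return 1.0f / std::sqrt(var + eps);
}
inline float invstd_to_variance(float invstd, float eps) {
  return 1.0f / (invstd * invstd) - eps;
}

int main() {
  const float eps = 1e-5f;
  const float var = 0.25f;
  const float invstd = variance_to_invstd(var, eps);   // ~ 1 / sqrt(0.25)
  const float back = invstd_to_variance(invstd, eps);  // recovers ~0.25
  std::printf("var=%g invstd=%g round-trip=%g\n", var, invstd, back);
  assert(std::fabs(back - var) < 1e-6f);
  return 0;
}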
@@ -357,10 +357,10 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
if (param.axis != 1 || shape.ndim() != 4) {
// reshape to (N, C, 1, D)
mxnet::TShape new_shape{
- static_cast<dim_t>(shape.ProdShape(0, real_axis)),
+ static_cast<index_t>(shape.ProdShape(0, real_axis)),
shape[real_axis],
1,
- static_cast<dim_t>(shape.ProdShape(real_axis + 1,
+ static_cast<index_t>(shape.ProdShape(real_axis + 1,
static_cast<int>(shape.ndim())))
};
data = data.Reshape(new_shape);
@@ -384,15 +384,15 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
const NDArray &gamma = in_data[batchnorm::kGamma];
const NDArray &beta = in_data[batchnorm::kBeta];
DType *weight_buf = reinterpret_cast<DType *>(bwd.GetWeight().get_data_handle());
- nnvm::dim_t channels_ = data.shape()[1];
+ index_t channels_ = data.shape()[1];
DType *weight_ptr = gamma.data().dptr<DType>();
DType* bias_ptr = beta.data().dptr<DType>();
const size_t copy_size = sizeof(DType) * channels_;
if (!param.fix_gamma) {
memcpy(weight_buf, weight_ptr, copy_size);
memcpy(&weight_buf[channels_], bias_ptr, copy_size);
} else {
- for (int i = 0; i < channels_; i++) {
+ for (index_t i = 0; i < channels_; i++) {
weight_buf[i] = static_cast<DType>(1.0f);
}
memcpy(&weight_buf[channels_], bias_ptr, copy_size);
@@ -422,7 +422,7 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
DType *tmp_var_ptr = reinterpret_cast<DType *>(var_mem.get_data_handle());

DType minus_mom = (1.0f - param.momentum);
- for (int i = 0; i < channels_; i++) {
+ for (index_t i = 0; i < channels_; i++) {
moving_mean_ptr[i] = moving_mean_ptr[i] * param.momentum +
out_mean_ptr[i] * minus_mom;
float variance = INVSTD_TO_VARIANCE(out_var_ptr[i], param.eps);
@@ -451,13 +451,13 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
if (req[batchnorm::kGamma] != kAddTo) {
memcpy(w_grad_1, gw_buf, copy_size);
} else {
- for (int i = 0; i < channels_; i++) {
+ for (index_t i = 0; i < channels_; i++) {
w_grad_1[i] += gw_buf[i];
}
}
}
} else {
- for (int i = 0; i < channels_; i++) {
+ for (index_t i = 0; i < channels_; i++) {
(in_grad[1].data().dptr<DType>())[i] = 0.0f;
}
}
@@ -468,7 +468,7 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
memcpy(w_grad_2, &gw_buf[channels_], copy_size);
} else {
DType *grad_beta = &gw_buf[channels_];
- for (int i = 0; i < channels_; i++) {
+ for (index_t i = 0; i < channels_; i++) {
w_grad_2[i] += grad_beta[i];
}
}
