Skip to content

Commit

Permalink
Fast cuDNN BatchNorm NHWC kernels support (apache#20615)
Browse files Browse the repository at this point in the history
* Fast cuDNN NHWC kernels support

* Fix lint errors

* Get rid of a warning

* Remove CuDNNBatchNorm from AMP lists

Co-authored-by: Vladimir Cherepanov <[email protected]>
  • Loading branch information
2 people authored and chinakook committed Feb 11, 2022
1 parent 173ca26 commit bf50cd4
Show file tree
Hide file tree
Showing 4 changed files with 275 additions and 29 deletions.
2 changes: 1 addition & 1 deletion src/operator/nn/batch_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -633,10 +633,10 @@ then set ``gamma`` to 1 and its gradient to 0.
.set_attr<nnvm::FGradient>("FGradient", BatchNormGrad)
#if MXNET_USE_MKLDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
#endif
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
#endif
.add_argument("data", "NDArray-or-Symbol", "Input data to batch normalization")
.add_argument("gamma", "NDArray-or-Symbol", "gamma array")
.add_argument("beta", "NDArray-or-Symbol", "beta array")
Expand Down
36 changes: 8 additions & 28 deletions src/operator/nn/batch_norm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
#define ADDTO_BETA_FLAG (1 << 8)

#if MXNET_USE_CUDNN == 1
#include "./cudnn/cudnn_batch_norm-inl.h"
#include "./cudnn/cudnn_batch_norm.h"
#endif

#include "../../common/cuda_utils.h"
Expand Down Expand Up @@ -928,11 +928,6 @@ static void BatchNormalizationBackward(mshadow::Stream<gpu> *s,
(flags & IS_TRAINING_FLAG) != 0 && (flags & USE_GLOBAL_STATS_FLAG) == 0;

if (is_train_and_not_global_stats) {
#ifdef NDEBUG
constexpr bool SMALLER_THREADS = false;
#else
constexpr bool SMALLER_THREADS = true;
#endif
dim3 blocks(gradOutput.ChannelCount());
dim3 threads(batchnorm::cuda::getNumThreads(gradOutput.InnerSize()));
BatchNormalizationBackwardKernel<DType, AccReal, DeviceTensor1, batchnorm::BNTensor3<DType>>
Expand Down Expand Up @@ -1082,20 +1077,7 @@ void BatchNormBackwardImpl(mshadow::Stream<gpu> *stream,
MSHADOW_CUDA_POST_KERNEL_CHECK(BatchNormOp_DoBackward_gpu);
}

#if MXNET_USE_CUDNN == 1
template<typename DType>
static CuDNNBatchNormOp<DType> &GetCuDNNOp(const BatchNormParam& param) {
#if DMLC_CXX11_THREAD_LOCAL
static thread_local CuDNNBatchNormOp<DType> op;
#else
static MX_THREAD_LOCAL CuDNNBatchNormOp<DType> op;
#endif
op.Init(param);
return op;
}
#endif

template<>
template <>
void BatchNormCompute<gpu>(const nnvm::NodeAttrs& attrs,
const OpContext& ctx, const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
Expand All @@ -1109,10 +1091,9 @@ void BatchNormCompute<gpu>(const nnvm::NodeAttrs& attrs,

param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
#if MXNET_USE_CUDNN == 1
if (!param.use_global_stats && !param.cudnn_off) {
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
GetCuDNNOp<DType>(param).Forward(ctx, in_data, req, outputs, aux_states);
})
if (!param.use_global_stats && !param.cudnn_off &&
CudnnBatchNormSupports(param, inputs[batchnorm::kData])) {
CudnnBatchNormForward(param, ctx, inputs, req, outputs);
} else {
MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DType, AccReal, {
BatchNormForward<gpu, DType, AccReal>(ctx, param, in_data, req, outputs, aux_states);
Expand All @@ -1137,10 +1118,9 @@ void BatchNormGradCompute<gpu>(const nnvm::NodeAttrs& attrs,

param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
#if MXNET_USE_CUDNN == 1
if (!param.use_global_stats && !param.cudnn_off) {
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
GetCuDNNOp<DType>(param).Backward(ctx, inputs, req, outputs);
})
if (!param.use_global_stats && !param.cudnn_off &&
CudnnBatchNormSupports(param, inputs[3 + batchnorm::kData])) {
CudnnBatchNormBackward(param, ctx, inputs, req, outputs);
} else {
MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DType, AccReal, {
BatchNormBackward<gpu, DType, AccReal>(ctx, param, inputs, req, outputs);
Expand Down
210 changes: 210 additions & 0 deletions src/operator/nn/cudnn/cudnn_batch_norm.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2015 by Contributors
* \file cudnn_batch_norm.cu
* \brief
* \author Junyuan Xie, Da Zheng
*/

#include "cudnn_batch_norm.h"

#include "../../../common/cuda_utils.h"

namespace mxnet {
namespace op {

#if MXNET_USE_CUDNN == 1

namespace {

struct Globals {
cudnnTensorDescriptor_t io_desc;
cudnnTensorDescriptor_t mean_desc;
bool internal_aux_states_lock = false;

static Globals& Get() {
thread_local Globals ret;
return ret;
}

Globals() {
CUDNN_CALL(cudnnCreateTensorDescriptor(&io_desc));
CUDNN_CALL(cudnnCreateTensorDescriptor(&mean_desc));
}

~Globals() {
CUDNN_CALL(cudnnDestroyTensorDescriptor(io_desc));
CUDNN_CALL(cudnnDestroyTensorDescriptor(mean_desc));
}
};

void SetDescriptors(const BatchNormParam& param, const TBlob& x) {
CHECK_GE(x.shape_.ndim(), 3);
CHECK(param.axis == 1 || param.axis == x.shape_.ndim() - 1);

cudnnTensorFormat_t format = param.axis == 1 ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC;
int n = x.shape_[0];
int c = x.shape_[param.axis];
size_t last_spatial_i = param.axis == 1 ? x.shape_.ndim() - 1 : x.shape_.ndim() - 2;
int w = x.shape_[last_spatial_i];
int h = x.shape_.ProdShape(last_spatial_i - (x.shape_.ndim() - 3), last_spatial_i);

MSHADOW_REAL_TYPE_SWITCH(x.type_flag_, DType, {
CUDNN_CALL(cudnnSetTensor4dDescriptor(Globals::Get().io_desc, format,
mshadow::DataType<DType>::kCudnnFlag, n, c, h, w));
})
CUDNN_CALL(cudnnDeriveBNTensorDescriptor(Globals::Get().mean_desc, Globals::Get().io_desc,
CUDNN_BATCHNORM_SPATIAL));
}

mshadow::TypeFlag ParamType(int x_type) {
auto xt = static_cast<mshadow::TypeFlag>(x_type);
return xt == mshadow::kFloat16 ? mshadow::kFloat32 : xt;
}

} // namespace

bool CudnnBatchNormSupports(const BatchNormParam& param, const TBlob& x) {
int n = x.shape_.ndim();
return n >= 3 && (param.axis == 1 || param.axis == n - 1);
}

void CudnnBatchNormForward(const BatchNormParam& param, const OpContext& ctx,
const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
CHECK_EQ(inputs.size(), 5);
if (ctx.is_train) {
CHECK_EQ(outputs.size(), 3);
CHECK_EQ(req.size(), 3);
} else {
CHECK_GE(outputs.size(), 1);
CHECK_GE(req.size(), 1);
}
CHECK_EQ(req[batchnorm::kOut], kWriteTo);
CHECK_GE(inputs[batchnorm::kData].ndim(), 2);

SetDescriptors(param, inputs[batchnorm::kData]);

auto s = ctx.get_stream<gpu>();
MSHADOW_REAL_TYPE_SWITCH(ParamType(inputs[batchnorm::kData].type_flag_), DType, {
DType a = 1.0f;
DType b = 0.0f;
if (param.fix_gamma) inputs[batchnorm::kGamma].FlatTo1D<gpu, DType>(s) = 1.0f;
if (ctx.is_train) {
size_t workspace_size = 0;
CUDNN_CALL(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
s->dnn_handle_, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, CUDNN_BATCHNORM_OPS_BN,
Globals::Get().io_desc, nullptr, Globals::Get().io_desc, Globals::Get().mean_desc,
nullptr, &workspace_size));
auto workspace = ctx.requested[0].get_space_internal(workspace_size, "CudnnBatchNormForward");

// If the lock on the auxiliary states is set, then this implies that
// the preceding call is also a `Forward()` call, which further
// indicates that we are in the backward mirroring mode, and therefore
// update to the auxiliary states is disabled. This is done by setting
// the `momentum` to `1` (or `factor` to `0`).
double factor =
((dmlc::GetEnv("MXNET_BACKWARD_DO_MIRROR", 0) || dmlc::GetEnv("MXNET_MEMORY_OPT", 0)) &&
Globals::Get().internal_aux_states_lock)
? 0
: (1 - param.momentum);
CUDNN_CALL(cudnnBatchNormalizationForwardTrainingEx(
s->dnn_handle_, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, CUDNN_BATCHNORM_OPS_BN, &a, &b,
Globals::Get().io_desc, inputs[batchnorm::kData].dptr_,
nullptr, nullptr, // zDesc, zData
Globals::Get().io_desc, outputs[batchnorm::kOut].dptr_,
Globals::Get().mean_desc,
inputs[batchnorm::kGamma].dptr_, inputs[batchnorm::kBeta].dptr_,
factor, inputs[batchnorm::kInMovingMean].dptr_, inputs[batchnorm::kInMovingVar].dptr_,
param.eps, outputs[batchnorm::kMean].dptr_, outputs[batchnorm::kVar].dptr_,
nullptr, // activation desc
workspace, workspace_size,
nullptr, 0)); // reserveSpace, reserveSpaceSizeInBytes
} else {
CUDNN_CALL(cudnnBatchNormalizationForwardInference(
s->dnn_handle_, CUDNN_BATCHNORM_SPATIAL, &a, &b,
Globals::Get().io_desc, inputs[batchnorm::kData].dptr_,
Globals::Get().io_desc, outputs[batchnorm::kOut].dptr_,
Globals::Get().mean_desc,
inputs[batchnorm::kGamma].dptr_, inputs[batchnorm::kBeta].dptr_,
inputs[batchnorm::kInMovingMean].dptr_, inputs[batchnorm::kInMovingVar].dptr_,
param.eps));
}
})
// Set the lock on the auxiliary states.
// If the next call to the operator is a `Forward()` call,
// then `momentum` will be set to `1` and hence auxiliary states will not be updated.
Globals::Get().internal_aux_states_lock = true;
}

void CudnnBatchNormBackward(const BatchNormParam& param, const OpContext& ctx,
const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
CHECK_EQ(inputs.size(), 8);
CHECK_EQ(outputs.size(), 3);
CHECK_EQ(req.size(), 3);

SetDescriptors(param, inputs[3 + batchnorm::kData]);
auto s = ctx.get_stream<gpu>();
size_t workspace_size = 0;
CUDNN_CALL(cudnnGetBatchNormalizationBackwardExWorkspaceSize(
s->dnn_handle_, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, CUDNN_BATCHNORM_OPS_BN,
Globals::Get().io_desc, Globals::Get().io_desc, Globals::Get().io_desc, nullptr,
Globals::Get().io_desc, Globals::Get().mean_desc, nullptr, &workspace_size));
auto workspace = ctx.requested[0].get_space_internal(workspace_size, "CudnnBatchNormBackward");
MSHADOW_REAL_TYPE_SWITCH(ParamType(inputs[3 + batchnorm::kData].type_flag_), DType, {
if (param.fix_gamma) inputs[3 + batchnorm::kGamma].FlatTo1D<gpu, DType>(s) = 1.0f;
bool grad_add_gamma_beta = req[batchnorm::kGamma] == kAddTo || req[batchnorm::kBeta] == kAddTo;
if (grad_add_gamma_beta) {
if (IsBNWriting(req[batchnorm::kGamma]))
outputs[batchnorm::kGamma].FlatTo1D<gpu, DType>(s) = 0.0f;
if (IsBNWriting(req[batchnorm::kBeta]))
outputs[batchnorm::kBeta].FlatTo1D<gpu, DType>(s) = 0.0f;
}
DType a = 1.0f;
DType b = 0.0f;
DType b_add = 1.0f;
const bool global_stats = !ctx.is_train || param.use_global_stats;
CUDNN_CALL(cudnnBatchNormalizationBackwardEx(
s->dnn_handle_, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, CUDNN_BATCHNORM_OPS_BN,
&a, req[batchnorm::kData] == kAddTo ? &b_add : &b,
&a, grad_add_gamma_beta ? &b_add : &b,
Globals::Get().io_desc, inputs[3 + batchnorm::kData].dptr_,
nullptr, nullptr, // yDesc, yData
Globals::Get().io_desc, inputs[batchnorm::kOut].dptr_,
nullptr, nullptr, // dzDesc, dzData
Globals::Get().io_desc, outputs[batchnorm::kData].dptr_,
Globals::Get().mean_desc,
inputs[3 + batchnorm::kGamma].dptr_, inputs[3 + batchnorm::kBeta].dptr_,
outputs[batchnorm::kGamma].dptr_, outputs[batchnorm::kBeta].dptr_, param.eps,
global_stats ? nullptr : inputs[batchnorm::kMean].dptr_,
global_stats ? nullptr : inputs[batchnorm::kVar].dptr_,
nullptr, // activationDesc
workspace, workspace_size,
nullptr, 0)); // reserveSpace, reserveSpaceSizeInBytes
if (param.fix_gamma) outputs[batchnorm::kGamma].FlatTo1D<gpu, DType>(s) = 0.0f;
})
Globals::Get().internal_aux_states_lock = false;
}

#endif // MXNET_USE_CUDNN == 1
} // namespace op
} // namespace mxnet
56 changes: 56 additions & 0 deletions src/operator/nn/cudnn/cudnn_batch_norm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2015 by Contributors
* \file cudnn_batch_norm.h
* \brief
* \author Junyuan Xie
*/

#ifndef MXNET_OPERATOR_NN_CUDNN_CUDNN_BATCH_NORM_H_
#define MXNET_OPERATOR_NN_CUDNN_CUDNN_BATCH_NORM_H_

#include <mxnet/base.h>
#include <vector>
#include "../batch_norm-inl.h"

namespace mxnet {
namespace op {

#if MXNET_USE_CUDNN == 1

STATIC_ASSERT_CUDNN_VERSION_GE(7401);

bool CudnnBatchNormSupports(const BatchNormParam& param, const TBlob& x);

void CudnnBatchNormForward(const BatchNormParam& param, const OpContext& ctx,
const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs);

void CudnnBatchNormBackward(const BatchNormParam& param, const OpContext& ctx,
const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs);

#endif // MXNET_USE_CUDNN == 1

} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_NN_CUDNN_CUDNN_BATCH_NORM_H_

0 comments on commit bf50cd4

Please sign in to comment.