Fast cuDNN BatchNorm NHWC kernels support (#20615)
* Fast cuDNN NHWC kernels support

* Fix lint errors

* Get rid of a warning

* Remove CuDNNBatchNorm from AMP lists

Co-authored-by: Vladimir Cherepanov <[email protected]>
mk-61 and Vladimir Cherepanov authored Sep 30, 2021
1 parent f25b92e commit 23af413
Showing 7 changed files with 274 additions and 463 deletions.
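
Background note: "NHWC" in the commit title is the channel-last memory layout, in contrast to MXNet's default NCHW. The standalone C++ sketch below (plain code, not MXNet sources) shows how the two layouts linearize the same 4-D tensor element; the fast cuDNN kernels wired up by this commit operate on the channel-last form.

// Standalone illustration (not MXNet code): NCHW and NHWC are two
// linearizations of the same 4-D tensor (batch, channel, height, width).
#include <cstddef>
#include <cstdio>

// Flat offset of element (n, c, h, w) in each layout.
std::size_t OffsetNCHW(std::size_t n, std::size_t c, std::size_t h, std::size_t w,
                       std::size_t C, std::size_t H, std::size_t W) {
  return ((n * C + c) * H + h) * W + w;
}
std::size_t OffsetNHWC(std::size_t n, std::size_t c, std::size_t h, std::size_t w,
                       std::size_t C, std::size_t H, std::size_t W) {
  return ((n * H + h) * W + w) * C + c;
}

int main() {
  // Same logical element, different physical position in the two layouts.
  const std::size_t C = 64, H = 7, W = 7;
  std::printf("NCHW offset: %zu\n", OffsetNCHW(0, 3, 2, 5, C, H, W));
  std::printf("NHWC offset: %zu\n", OffsetNHWC(0, 3, 2, 5, C, H, W));
}
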
python/mxnet/amp/lists/symbol_fp16.py — 5 changes: 0 additions & 5 deletions
@@ -459,11 +459,6 @@
'zeros_like',
]

-if Features().is_enabled('CUDNN'):
-    FP16_FP32_FUNCS.extend([
-        'CuDNNBatchNorm',
-    ])

# Functions that have to be cast to FP32 due to possible
# overflows
FP32_FUNCS = [
src/operator/nn/batch_norm.cc — 2 changes: 1 addition & 1 deletion
@@ -649,11 +649,11 @@ then set ``gamma`` to 1 and its gradient to 0.
.set_attr<nnvm::FGradient>("FGradient", BatchNormGrad)
#if MXNET_USE_ONEDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
+#endif
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
-#endif
.add_argument("data", "NDArray-or-Symbol", "Input data to batch normalization")
.add_argument("gamma", "NDArray-or-Symbol", "gamma array")
.add_argument("beta", "NDArray-or-Symbol", "beta array")
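
The FResourceRequest attribute in this hunk is simply a callable that tells the executor which scratch resources the operator needs at run time; here it requests temporary space. A minimal standalone sketch follows, with stub types (the ResourceRequest enum, NodeAttrs struct, and type alias below are simplified stand-ins, not the real nnvm/MXNet definitions):

// Standalone sketch with stub types, not the nnvm/MXNet API.
#include <cstdio>
#include <functional>
#include <vector>

enum class ResourceRequest { kTempSpace };
struct NodeAttrs {};
using FResourceRequest = std::function<std::vector<ResourceRequest>(const NodeAttrs&)>;

int main() {
  // Mirrors the lambda registered in the diff: ask the executor for scratch space.
  FResourceRequest request = [](const NodeAttrs&) {
    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
  };
  std::printf("operator requests %zu resource(s)\n", request(NodeAttrs{}).size());
  return 0;
}
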
src/operator/nn/batch_norm.cu — 32 changes: 7 additions & 25 deletions
@@ -39,7 +39,7 @@
#define ADDTO_BETA_FLAG (1 << 8)

#if MXNET_USE_CUDNN == 1
#include "./cudnn/cudnn_batch_norm-inl.h"
#include "./cudnn/cudnn_batch_norm.h"
#endif

#include "../../../include/mxnet/tensor_blob.h"
@@ -935,11 +935,6 @@ static void BatchNormalizationBackward(mshadow::Stream<gpu>* s,
(flags & IS_TRAINING_FLAG) != 0 && (flags & USE_GLOBAL_STATS_FLAG) == 0;

if (is_train_and_not_global_stats) {
-#ifdef NDEBUG
-    constexpr bool SMALLER_THREADS = false;
-#else
-    constexpr bool SMALLER_THREADS = true;
-#endif
dim3 blocks(gradOutput.ChannelCount());
dim3 threads(batchnorm::cuda::getNumThreads(gradOutput.InnerSize()));
BatchNormalizationBackwardKernel<DType, AccReal, DeviceTensor1, batchnorm::BNTensor3<DType>>
@@ -1104,19 +1099,6 @@ void BatchNormBackwardImpl(mshadow::Stream<gpu>* stream,
MSHADOW_CUDA_POST_KERNEL_CHECK(BatchNormOp_DoBackward_gpu);
}

-#if MXNET_USE_CUDNN == 1
-template <typename DType>
-static CuDNNBatchNormOp<DType>& GetCuDNNOp(const BatchNormParam& param) {
-#if DMLC_CXX11_THREAD_LOCAL
-  static thread_local CuDNNBatchNormOp<DType> op;
-#else
-  static MX_THREAD_LOCAL CuDNNBatchNormOp<DType> op;
-#endif
-  op.Init(param);
-  return op;
-}
-#endif

template <>
void BatchNormCompute<gpu>(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
@@ -1132,9 +1114,9 @@ void BatchNormCompute<gpu>(const nnvm::NodeAttrs& attrs,

param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
#if MXNET_USE_CUDNN == 1
-  if (!param.use_global_stats && !param.cudnn_off) {
-    MSHADOW_REAL_TYPE_SWITCH(
-        dtype, DType, { GetCuDNNOp<DType>(param).Forward(ctx, in_data, req, outputs, aux_states); })
+  if (!param.use_global_stats && !param.cudnn_off &&
+      CudnnBatchNormSupports(param, inputs[batchnorm::kData])) {
+    CudnnBatchNormForward(param, ctx, inputs, req, outputs);
} else {
MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DType, AccReal, {
BatchNormForward<gpu, DType, AccReal>(ctx, param, in_data, req, outputs, aux_states);
@@ -1160,9 +1142,9 @@ void BatchNormGradCompute<gpu>(const nnvm::NodeAttrs& attrs,

param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
#if MXNET_USE_CUDNN == 1
-  if (!param.use_global_stats && !param.cudnn_off) {
-    MSHADOW_REAL_TYPE_SWITCH(
-        dtype, DType, { GetCuDNNOp<DType>(param).Backward(ctx, inputs, req, outputs); })
+  if (!param.use_global_stats && !param.cudnn_off &&
+      CudnnBatchNormSupports(param, inputs[3 + batchnorm::kData])) {
+    CudnnBatchNormBackward(param, ctx, inputs, req, outputs);
} else {
MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DType, AccReal, {
BatchNormBackward<gpu, DType, AccReal>(ctx, param, inputs, req, outputs);
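
Taken together, the batch_norm.cu changes drop the per-type thread-local CuDNNBatchNormOp cache and dispatch through free functions guarded by a support check: the cuDNN path runs only when it is not disabled and CudnnBatchNormSupports accepts the parameters and input, otherwise the generic kernel is used. The standalone C++ sketch below mirrors that dispatch shape; the axis-based support check is an invented stand-in, since the real check lives in the new cudnn_batch_norm.h, which is not shown on this page.

// Standalone illustration of the capability-gated dispatch; names are stubs.
#include <cstdio>

// Stub parameter struct; only the fields used by the gate are modeled.
struct BatchNormParam {
  bool use_global_stats = false;
  bool cudnn_off = false;
  int axis = 1;  // channel axis: 1 for NCHW, 3 (or -1) for NHWC
};

// Stand-in for CudnnBatchNormSupports(); here we only gate on the channel axis.
bool FusedPathSupports(const BatchNormParam& p) {
  return p.axis == 3 || p.axis == -1;
}

void FusedForward()   { std::puts("fused cuDNN NHWC batch-norm kernel"); }
void GenericForward() { std::puts("generic batch-norm kernel"); }

// Same shape as the new BatchNormCompute<gpu> gate: prefer the fused path,
// fall back to the generic implementation otherwise.
void Dispatch(const BatchNormParam& p) {
  if (!p.use_global_stats && !p.cudnn_off && FusedPathSupports(p)) {
    FusedForward();
  } else {
    GenericForward();
  }
}

int main() {
  BatchNormParam nchw;                  // axis = 1 -> generic kernel
  BatchNormParam nhwc;  nhwc.axis = 3;  // axis = 3 -> fused kernel
  Dispatch(nchw);
  Dispatch(nhwc);
  return 0;
}
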
src/operator/nn/cudnn/cudnn_batch_norm-inl.h — 307 changes: 0 additions & 307 deletions

This file was deleted.

