diff --git a/src/operator/nn/depthwise_convolution-inl.h b/src/operator/nn/depthwise_convolution-inl.h
index 0af8cae51c84..69e6f693b852 100644
--- a/src/operator/nn/depthwise_convolution-inl.h
+++ b/src/operator/nn/depthwise_convolution-inl.h
@@ -79,80 +79,6 @@ class DepthwiseConvolutionOp {
 };  // class DepthwiseConvolutionOp
 
 namespace depthwise_conv {
-namespace cuda {
-template<typename DType, int kFilterWidth, int kFilterHeight>
-__global__ void __launch_bounds__(1024, 2)
-DepthwiseConv2dBackwardFilterKernel(const DepthwiseArgs args,
-                                    const DType* out_grad,
-                                    const DType* input,
-                                    DType* filter_grad) {
-  const int in_height = args.in_height;
-  const int in_width = args.in_width;
-  const int channel = args.in_channel;
-  const int filter_height = kFilterHeight > 0 ? kFilterHeight : args.filter_height;
-  const int filter_width = kFilterWidth > 0 ? kFilterWidth : args.filter_width;
-  const int stride_height = args.stride_height;
-  const int stride_width = args.stride_width;
-  const int pad_height = args.pad_height;
-  const int pad_width = args.pad_width;
-  const int out_height = args.out_height;
-  const int out_width = args.out_width;
-
-  const int filter_pixels = filter_width * filter_height;
-  const int out_pixels = out_height * out_width;
-  const int in_pixels = in_height * in_width;
-  const int batch_channel_num = channel * args.batch;
-
-  for (int b = blockIdx.x; b < batch_channel_num; b += gridDim.x) {
-    const int local_batch = b / channel;
-    const int local_channel = b % channel;
-    const int filter_offset_temp = local_channel * filter_pixels;
-    const int out_grad_offset_temp = (local_batch * channel * out_pixels) +
-        (local_channel * out_pixels);
-
-    // Make sure all threads enter the loop so they get to the enclosed __syncthreads()
-    for (int out_id = threadIdx.x;
-         out_id < ROUND_TO_MULTIPLE(out_pixels,
-                                    blockDim.x); out_id += blockDim.x) {
-      const int out_w = out_id % out_width;
-      const int out_h = (out_id / out_width) % out_height;
-      const int out_grad_offset = out_grad_offset_temp + (out_h * out_width) + (out_w);
-      // Set out_g to 0 if the thread would normally have not entered the loop.
-      const DType out_g = out_id < out_pixels ? ldg(out_grad + out_grad_offset) : DType(0);
-
-      const int in_h_start = out_h * stride_height - pad_height;
-      const int in_w_start = out_w * stride_width - pad_width;
-      CUDA_UNROLL for (int f_h = 0; f_h < filter_height; ++f_h) {
-        const int in_h = in_h_start + f_h;
-        const int input_offset_temp = (local_batch * channel * in_pixels) +
-            (local_channel * in_pixels) + (in_h * in_width);
-        const int filter_offset_h = filter_width * f_h;
-
-        CUDA_UNROLL for (int f_w = 0; f_w < filter_width; ++f_w) {
-          const int in_w = in_w_start + f_w;
-          DType partial_grad = DType(0.0f);
-          if (in_h >= 0 && in_h < in_height && in_w >= 0 && in_w < in_width) {
-            const int input_offset = input_offset_temp + in_w;
-            // Set partial_grad to 0 if the thread would normally not have entered the loop.
-            partial_grad = out_id < out_pixels ? ldg(input + input_offset) * out_g : DType(0);
-          }
-          // reduce all valid partial grad in a block
-          typedef cub::BlockReduce<DType, mshadow::cuda::kBaseThreadNum> BlockReduceT;
-          __shared__ typename BlockReduceT::TempStorage temp_storage_reduce;
-          DType aggregate = BlockReduceT(temp_storage_reduce).Sum(partial_grad, blockDim.x);
-          if (threadIdx.x == 0) {
-            DType* addr = filter_grad + f_w + filter_offset_h + filter_offset_temp;
-            atomicAdd(addr, aggregate);
-          }
-          // The presense of __syncthreads() here means all threads must enter enclosing for-loops.
-          __syncthreads();
-        }  // for filter_width
-      }  // for filter_height
-    }  // for out_pixels
-    __syncthreads();
-  }  // for batch_channel_num
-}
-}  // namespace cuda
 
 template<typename DType>
 void DepthwiseConv2dForwardGpu(mshadow::Stream<gpu> *stream,
@@ -244,6 +170,7 @@ void DepthwiseConv2dBackwardFilterGpu(mshadow::Stream<gpu> *stream,
   using namespace mshadow;
   using namespace mshadow::expr;
   using namespace tf::depthwise_conv;
+  using namespace tf::depthwise_conv::cuda;
   Tensor<gpu, 4, DType> out_g = out_grad[conv::kOut].get<gpu, 4, DType>(stream);
   Tensor<gpu, 4, DType> in_d = in_data[conv::kData].get<gpu, 4, DType>(stream);
   Tensor<gpu, 4, DType> weight_grad = in_grad[conv::kWeight].get<gpu, 4, DType>(stream);
@@ -258,17 +185,19 @@ void DepthwiseConv2dBackwardFilterGpu(mshadow::Stream<gpu> *stream,
     auto s = mshadow::Stream<gpu>::GetStream(stream);
     int block_num = std::min(args.out_channel * args.batch, mshadow::cuda::kMaxGridNum);
     if (args.filter_width == 3 && args.filter_height == 3) {
-      cuda::DepthwiseConv2dBackwardFilterKernel<DType, 3, 3>
+      DepthwiseConv2dBackwardFilterKernel<DType, 3, 3>
          <<<block_num, mshadow::cuda::kBaseThreadNum, 0, s>>>(args,
                                                               out_g.dptr_,
                                                               in_d.dptr_,
-                                                              weight_grad.dptr_);
+                                                              weight_grad.dptr_,
+                                                              num_out_grad);
     } else {
-      cuda::DepthwiseConv2dBackwardFilterKernel<DType, -1, -1>
+      DepthwiseConv2dBackwardFilterKernel<DType, -1, -1>
          <<<block_num, mshadow::cuda::kBaseThreadNum, 0, s>>>(args,
                                                               out_g.dptr_,
                                                               in_d.dptr_,
-                                                              weight_grad.dptr_);
+                                                              weight_grad.dptr_,
+                                                              num_out_grad);
     }
     MSHADOW_CUDA_POST_KERNEL_CHECK(DepthwiseConv2dBackwardFilterKernel);
   }
diff --git a/src/operator/nn/depthwise_convolution_tf.cuh b/src/operator/nn/depthwise_convolution_tf.cuh
index e4dfd8292d2d..defcfed844a5 100644
--- a/src/operator/nn/depthwise_convolution_tf.cuh
+++ b/src/operator/nn/depthwise_convolution_tf.cuh
@@ -384,6 +384,96 @@ DepthwiseConv2dBackwardDataKernel(const DepthwiseArgs args,
   }
 }
 
+// A Cuda kernel to compute the depthwise convolution backprop w.r.t. filter.
+template<typename DType, int kFilterWidth, int kFilterHeight>
+__global__ void __launch_bounds__(640, 2)
+DepthwiseConv2dBackwardFilterKernel(const DepthwiseArgs args,
+                                    const DType* out_backprop,
+                                    const DType* input,
+                                    DType* filter_backprop,
+                                    int num_out_backprop) {
+  const int in_channel = args.in_channel;
+  const int in_height = args.in_height;
+  const int in_width = args.in_width;
+  const int filter_height = kFilterHeight > 0 ? kFilterHeight : args.filter_height;
+  const int filter_width = kFilterWidth > 0 ? kFilterWidth : args.filter_width;
+  const int stride_height = args.stride_height;
+  const int stride_width = args.stride_width;
+  const int pad_height = args.pad_height;
+  const int pad_width = args.pad_width;
+  const int out_channel = args.out_channel;
+  const int out_height = args.out_height;
+  const int out_width = args.out_width;
+
+  CUDA_KERNEL_LOOP(thread_id, num_out_backprop) {
+    // Compute the indexes of this thread in the output.
+    const int out_w = thread_id % out_width;
+    const int out_h = (thread_id / out_width) % out_height;
+    const int out_c = (thread_id / out_width / out_height) % out_channel;
+    const int out_b = thread_id / out_width / out_height / out_channel;
+    const int in_c = out_c;
+
+    // Decide if all input is valid, if yes, we can skip the boundary checks
+    // for each input.
+    const int in_row_start = out_h * stride_height - pad_height;
+    const int in_col_start = out_w * stride_width - pad_width;
+    const int in_row_end = in_row_start + filter_height;
+    const int in_col_end = in_col_start + filter_width;
+
+    const int out_backprop_offset =
+        (out_b * out_channel * out_height * out_width) +
+        (out_c * out_height * out_width) + (out_h * out_width) +
+        (out_w);
+
+    const DType out_bp = ldg(out_backprop + out_backprop_offset);
+    if (in_row_start >= 0 && in_col_start >= 0 &&
+        in_row_end < in_height && in_col_end < in_width) {
+      CUDA_UNROLL for (int f_h = 0; f_h < filter_height; ++f_h) {
+        const int in_row = in_row_start + f_h;
+        // Avoid repeated computation.
+        const int input_offset_temp =
+            (out_b * in_channel * in_height * in_width) +
+            (in_c * in_height * in_width) + (in_row * in_width);
+
+        CUDA_UNROLL for (int f_w = 0; f_w < filter_width; ++f_w) {
+          const int in_col = in_col_start + f_w;
+          const int input_offset = input_offset_temp + in_col;
+          DType partial_sum = ldg(input + input_offset) * out_bp;
+          DType* addr = filter_backprop + (in_c + in_channel * (f_w + filter_width * f_h));
+          atomicAdd(addr, partial_sum);
+        }
+      }
+    } else {
+      CUDA_UNROLL for (int f_h = 0; f_h < filter_height; ++f_h) {
+        const int in_row = in_row_start + f_h;
+        // Avoid repeated computation.
+        const int input_offset_temp =
+            (out_b * in_channel * in_height * in_width) +
+            (in_c * in_height * in_width) + (in_row * in_width);
+        CUDA_UNROLL for (int f_w = 0; f_w < filter_width; ++f_w) {
+          const int in_col = in_col_start + f_w;
+          const int addr_temp = filter_width * f_h;
+
+          if (in_row >= 0 && in_row < in_height && in_col >= 0 && in_col < in_width) {
+            const int input_offset = input_offset_temp + in_col;
+            DType partial_sum = ldg(input + input_offset) * out_bp;
+            DType* addr = filter_backprop + (in_c + in_channel * (f_w + addr_temp));
+            // Potentially many threads can add to the same address so we have
+            // to use atomic add here.
+            // TODO(jmchen): If atomic add turns out to be slow, we can:
+            // 1. allocate multiple buffers for the gradients (one for each
+            // example in a batch, for example). This can reduce the
+            // contention on the destination; 2. Have each thread compute one
+            // gradient for an element in the filters. This should work well
+            // when the input depth is big and filter size is not too small.
+            atomicAdd(addr, partial_sum);
+          }
+        }
+      }
+    }
+  }
+}
+
 // CUDA kernel to compute the depthwise convolution backward w.r.t. filter in
 // NCHW format, tailored for small images up to 32x32. Only use this kernel if
 // CanLaunchDepthwiseConv2dGPUSmall(args) returns true.
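For reference, below is a minimal single-threaded host sketch (not part of the patch; the function name and float-only signature are illustrative) of the accumulation the new DepthwiseConv2dBackwardFilterKernel performs, which can be handy for spot-checking small shapes. It assumes NCHW layouts for input and out_grad and indexes filter_grad exactly the way the kernel addresses filter_backprop, i.e. in_c + in_channel * (f_w + filter_width * f_h).

#include <vector>

// Naive CPU reference for the depthwise backward-filter accumulation (illustrative only).
void DepthwiseBackwardFilterRef(const std::vector<float>& out_grad,  // [N, C, out_h, out_w]
                                const std::vector<float>& input,     // [N, C, in_h, in_w]
                                std::vector<float>* filter_grad,     // indexed as in the kernel
                                int N, int C,
                                int in_h, int in_w, int out_h, int out_w,
                                int filter_h, int filter_w,
                                int stride_h, int stride_w,
                                int pad_h, int pad_w) {
  for (int b = 0; b < N; ++b) {
    for (int c = 0; c < C; ++c) {
      for (int oh = 0; oh < out_h; ++oh) {
        for (int ow = 0; ow < out_w; ++ow) {
          // One iteration here corresponds to one thread_id of the CUDA kernel loop.
          const float g = out_grad[((b * C + c) * out_h + oh) * out_w + ow];
          for (int f_h = 0; f_h < filter_h; ++f_h) {
            for (int f_w = 0; f_w < filter_w; ++f_w) {
              const int ih = oh * stride_h - pad_h + f_h;
              const int iw = ow * stride_w - pad_w + f_w;
              if (ih < 0 || ih >= in_h || iw < 0 || iw >= in_w) continue;
              const float x = input[((b * C + c) * in_h + ih) * in_w + iw];
              // Same destination index as the kernel's atomicAdd target:
              // filter_backprop + (in_c + in_channel * (f_w + filter_width * f_h)).
              (*filter_grad)[c + C * (f_w + filter_w * f_h)] += x * g;
            }
          }
        }
      }
    }
  }
}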