From f0979e4207d8e61c470f86f7ee0137402330b650 Mon Sep 17 00:00:00 2001 From: Alex Gladkov Date: Fri, 6 Nov 2020 18:53:54 -0800 Subject: [PATCH] conv1d_transpose speedup. (#6840) Improve performance of transposed convolution by avoiding redundant multiplication by zero values from dilated data. Co-authored-by: Ubuntu --- python/tvm/topi/cuda/conv1d_transpose_ncw.py | 75 +++++++++---------- .../python/test_topi_conv1d_transpose_ncw.py | 4 + 2 files changed, 40 insertions(+), 39 deletions(-) diff --git a/python/tvm/topi/cuda/conv1d_transpose_ncw.py b/python/tvm/topi/cuda/conv1d_transpose_ncw.py index 1ddbdcca9b36..58f53eab20ac 100644 --- a/python/tvm/topi/cuda/conv1d_transpose_ncw.py +++ b/python/tvm/topi/cuda/conv1d_transpose_ncw.py @@ -65,29 +65,46 @@ def conv1d_transpose_ncw(cfg, data, kernel, stride, padding, out_dtype, output_p out_width = (inp_width - 1) * stride + kernel_size - pad_left - pad_right + output_padding pad_left = kernel_size - 1 - pad_left pad_right = kernel_size - 1 - pad_right + output_padding - dilated_width = stride * (inp_width - 1) + 1 - data = te.compute( - (batch, inp_channels, pad_left + dilated_width + pad_right), + padded_width = pad_left + inp_width + pad_right + + padded_data = te.compute( + (batch, inp_channels, padded_width), lambda n, c, x: tvm.tir.if_then_else( - tvm.tir.all( - x >= pad_left, - x < pad_left + dilated_width, - tvm.tir.indexmod(x - pad_left, stride).equal(0), - ), - data[n, c, tvm.tir.indexdiv(x - pad_left, stride)], + tvm.tir.all(x >= pad_left, x < pad_left + inp_width), + data[n, c, x - pad_left], tvm.tir.const(0.0, "float32"), ), name="data_pad", ) - dc = te.reduce_axis((0, inp_channels), name="dc") - dw = te.reduce_axis((0, kernel_size), name="dw") + padded_kernel = te.compute( + (inp_channels, out_channels, kernel_size + stride - 1), + lambda ci, co, k: tvm.tir.if_then_else( + tvm.tir.all(k < kernel_size), + kernel[ci, co, kernel_size - k - 1], + tvm.tir.const(0.0, "float32"), + ), + name="kernel_pad", + ) + + 
ci = te.reduce_axis((0, inp_channels), name="ci") + k = te.reduce_axis((0, tvm.tir.indexdiv(kernel_size + stride - 1, stride)), name="k") + border = pad_left * (stride - 1) + + # Skip multiplication by 0 values in the input data inserted when stride is greater than 1. + # During multiplication of kernel by padded data: + # Kernel indices are: 0, 1 * stride, 2 * stride, ..., ceil(kernel_size / stride) plus + # data offset mod stride data_out = te.compute( (batch, out_channels, out_width), lambda b, co, w: te.sum( padded_data[b, ci, tvm.tir.indexdiv(border + w + stride - 1, stride) + k].astype( out_dtype ) * padded_kernel[ ci, co, k * stride + tvm.tir.indexmod(stride - w - border, stride) ].astype(out_dtype), axis=[ci, k], ), tag="conv1d_transpose_ncw", ) @@ -118,8 +135,8 @@ def schedule_conv1d_transpose_ncw(cfg, outs): def _callback(op): if op.tag == "conv1d_transpose_ncw": - pad_data = op.input_tensors[0] - kernel = op.input_tensors[1] + padded_data = op.input_tensors[0] + padded_kernel = op.input_tensors[1] conv = op.output(0) ##### space definition begin ##### @@ -139,9 +156,6 @@ def _callback(op): ##### space definition end ##### - if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: - s[kernel].compute_inline() - if conv.op in s.outputs: output = conv OL = s.cache_write(conv, "local") @@ -150,10 +164,8 @@ def _callback(op): s[conv].set_scope("local") OL = conv - # create cache stage - s[pad_data].set_scope("shared") - AA = pad_data - WW = s.cache_read(kernel, "shared", [OL]) + s[padded_kernel].compute_inline() + s[padded_data].compute_inline() # tile and bind spatial axes n, f, x = s[output].op.axis @@ -172,9 +184,6 @@ def _callback(op): s[output].bind(tx, te.thread_axis("threadIdx.x")) s[OL].compute_at(s[output], tx) - # number of threads - n_tz = cfg["tile_n"].size[2] * 
cfg["tile_f"].size[2] - n_tx = cfg["tile_x"].size[2] # tile reduction axes n, f, x = s[OL].op.axis @@ -182,18 +191,6 @@ def _callback(op): rco, rcm, rci = cfg["tile_rc"].apply(s, OL, rc) s[OL].reorder(rco, rcm, rx, rci, n, f, x) - s[AA].compute_at(s[OL], rx) - s[WW].compute_at(s[OL], rx) - - # cooperative fetching - for load in [AA, WW]: - n, f, x = s[load].op.axis - fused = s[load].fuse(f, x) - tz, fused = s[load].split(fused, nparts=n_tz) - tx, fused = s[load].split(fused, nparts=n_tx) - s[load].bind(tz, te.thread_axis("threadIdx.y")) - s[load].bind(tx, te.thread_axis("threadIdx.x")) - s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val) diff --git a/tests/python/topi/python/test_topi_conv1d_transpose_ncw.py b/tests/python/topi/python/test_topi_conv1d_transpose_ncw.py index c251283f8011..2b8c486b8cd1 100644 --- a/tests/python/topi/python/test_topi_conv1d_transpose_ncw.py +++ b/tests/python/topi/python/test_topi_conv1d_transpose_ncw.py @@ -91,9 +91,13 @@ def test_conv1d_transpose_ncw(): verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 2, 256, (0,)) verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 5, 256, (0,)) verify_conv1d_transpose_ncw(1, 1, 1024, 1, 512, 5, 256, (3,)) + verify_conv1d_transpose_ncw(1, 2, 1024, 1, 128, 128, 0, (0,)) + verify_conv1d_transpose_ncw(1, 1, 1024, 2, 128, 128, 0, (0,)) + verify_conv1d_transpose_ncw(1, 1, 1024, 2, 2, 2, 0, (0,)) verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (0, 3), (0,)) verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (1, 3), (0,)) verify_conv1d_transpose_ncw(1, 1, 10, 1, 5, 1, (2, 3), (0,)) + verify_conv1d_transpose_ncw(1, 257, 128, 1, 512, 128, 256, (0,)) if __name__ == "__main__":