From 959b1a3b541943504b9c85d8fa517e5d7dfe0029 Mon Sep 17 00:00:00 2001 From: Justin Luitjens Date: Fri, 2 Sep 2022 14:54:41 -0600 Subject: [PATCH] conv1d optimizations. (#259) Fixed complex math for float/double to generate only FFMAs/DFFMAs. Added ILP Fixed alignment issue Made robust to large signals Added unit tests and benchmark Co-authored-by: jluitjens --- bench/00_transform/conv.cu | 17 +++ include/matx/core/utils.h | 25 +++- include/matx/kernels/conv.cuh | 231 ++++++++++++++++----------------- include/matx/transforms/conv.h | 46 ++++--- test/00_transform/ConvCorr.cu | 48 +++++++ test/include/test_types.h | 3 + 6 files changed, 231 insertions(+), 139 deletions(-) diff --git a/bench/00_transform/conv.cu b/bench/00_transform/conv.cu index 0a2984f8..2bdee6a5 100644 --- a/bench/00_transform/conv.cu +++ b/bench/00_transform/conv.cu @@ -46,3 +46,20 @@ void conv1d_2d_batch(nvbench::state &state, [&out, &at, &bt](nvbench::launch &launch) { conv1d(out, at, bt, MATX_C_MODE_FULL, launch.get_stream()); }); } NVBENCH_BENCH_TYPES(conv1d_2d_batch, NVBENCH_TYPE_AXES(conv_types)); + +template +void conv1d_large(nvbench::state &state, + nvbench::type_list) +{ + auto out = make_tensor({39321704}); + auto at = make_tensor({39321600}); + auto bt = make_tensor({105}); + + out.PrefetchDevice(0); + at.PrefetchDevice(0); + bt.PrefetchDevice(0); + + state.exec( + [&out, &at, &bt](nvbench::launch &launch) { conv1d(out, at, bt, MATX_C_MODE_FULL, launch.get_stream()); }); +} +NVBENCH_BENCH_TYPES(conv1d_large, NVBENCH_TYPE_AXES(conv_types)); diff --git a/include/matx/core/utils.h b/include/matx/core/utils.h index f6352fe4..9f450801 100644 --- a/include/matx/core/utils.h +++ b/include/matx/core/utils.h @@ -65,7 +65,7 @@ template __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ auto madd( const T1 &x, const T2 &y, const T3 &z) { using T4 = decltype(x*y+z); if constexpr (is_complex_v && !is_complex_half_v) { - + using value_type = typename T4::value_type; value_type xr, xi; @@ -109,13 +109,34 @@ __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ auto madd( const T1 &x, const T2 & //__half2 X = make_half2(x.real(), x.imag()); //__half2 Y = make_half2(y.real(), y.imag()); //__half2 Z = make_half2(z.real(), z.imag()); - + const __half2 &X = *reinterpret_cast(&x); const __half2 &Y = *reinterpret_cast(&y); const __half2 &Z = *reinterpret_cast(&z); +#if 1 auto v = __hcmadd(X,Y,Z); return T4(v.x, v.y); +#else + // In theory this could be faster but compiler is not folding broadcast/swap into HFMAs + + __half2 ari = make_half2(X.x, X.y); + // negate and swap supported in hardware sm_8.6+ + __half2 air = make_half2(X.y, __hneg(X.x)); + // broadcast supported in hardware + __half2 brr = make_half2(Y.x, Y.x); + // broadcast supported in hardware + __half2 bii = make_half2(Y.y, Y.y); + __half2 c = Z; + __half2 d; + + // HFMA2 RD, RA.H1_H0, RB.H1_H1, RC.H1_H0 + d = __hfma2(ari, brr, c); + // HFMA2 RD, RB.H0_H0, -RA.H0_NH1, RC.H1_H0 + d = __hfma2(bii, -air, d); + + return T4(d.x, d.y); +#endif } else { return x*y+z; } diff --git a/include/matx/kernels/conv.cuh b/include/matx/kernels/conv.cuh index c1a3495c..8e948011 100644 --- a/include/matx/kernels/conv.cuh +++ b/include/matx/kernels/conv.cuh @@ -13,7 +13,7 @@ #include "matx/core/type_utils.h" #include "matx/core/tensor_utils.h" -#define BLOCK_SIZE_NON_RECURSIVE 1024 +#define CONV1D_ELEMENTS_PER_BLOCK 512 namespace matx { @@ -30,168 +30,161 @@ typedef enum { } matxConvCorrMethod_t; #ifdef __CUDACC__ -template -__launch_bounds__(1024) +template +__launch_bounds__(THREADS) __global__ 
void Conv1D(OutType d_out, InType d_in, FilterType d_filter, index_t signal_len, matxConvCorrMode_t mode) { + + /* strategy: + 1 thread per EPT outputs. + Each block produces EPT * THREADS outputs + Full convolution is computed and results are windowed down based on the request + Filter is fully loaded into shared memory + Chunk of signal is loaded into shared memory with filter_len pandding on the negative side. + If out of range then we pad with zeros. + */ static_assert(InType::Rank() == FilterType::Rank()); const int Rank = InType::Rank(); - extern __shared__ float s_exch[]; // Filter + halo + extern __shared__ char s_exch1d[]; // Filter + halo using ftype_strip = typename FilterType::scalar_type; using intype_strip = typename InType::scalar_type; using outtype_strip = typename OutType::scalar_type; - int chunk_idx = blockIdx.y; int batch_idx = blockIdx.x; - int32_t filter_len = d_filter.Size(Rank-1); + uint32_t filter_len = d_filter.Size(Rank-1); + uint32_t full_len = signal_len + filter_len - 1; // All but the last dim will be populated auto bdims = BlockToIdx(d_in, batch_idx, 1); - // Adjustment to keep base shm size as float, but if the filter is complex we - // need to adjust it - constexpr float filt_size_adj = - static_cast(sizeof(ftype_strip)) / sizeof(s_exch[0]); + ftype_strip *s_filter = reinterpret_cast(&s_exch1d[0]); - ftype_strip *s_filter = reinterpret_cast(&s_exch[0]); - intype_strip *s_data; - - // If the data type has a higher alignment type than the filter, we need to - // adjust our shm pointer - if constexpr (std::alignment_of_v < intype_strip >> - std::alignment_of_v) { - s_data = - matx::detail::AlignAddr((uint8_t *)&s_exch[static_cast( - filter_len * filt_size_adj)]); // Start data portion after 2x the - // filter to remove conditionals and - // multiply by 0 - } - else { - s_data = reinterpret_cast(&s_exch[static_cast( - filter_len * - filt_size_adj)]); // Start data portion after 2x the filter to - // remove conditionals and multiply by 0 + size_t filter_bytes = filter_len * sizeof(ftype_strip); + // pad bytes to alignmetn of InType + int align = std::alignment_of_v; + filter_bytes = (filter_bytes + align - 1) / align * align; + + intype_strip *s_data = reinterpret_cast(&s_exch1d[filter_bytes]); + + // load filter + for (uint32_t idx = threadIdx.x; idx < filter_len; idx += THREADS) { + bdims[Rank - 1] = idx; + detail::mapply([&](auto &&...args) { + s_filter[idx] = d_filter.operator()(args...); + }, bdims); } - index_t full_len = signal_len + filter_len - 1; + // number of chunks in the signal, number of output elements / chunk size rounded up + uint32_t num_chunks = (signal_len + filter_len -1 + CONV1D_ELEMENTS_PER_BLOCK - 1) / CONV1D_ELEMENTS_PER_BLOCK; - // This is the location that is written in memory. Note that there will be - // duplicate tids based on this formula, but not all threads write out to - // memory. 
Some are only there to fetch data, while others both fetch and - // compute output - const int32_t tid = - static_cast(chunk_idx) * (blockDim.x - filter_len + 1) + - threadIdx.x; - int offset = tid - filter_len + 1; + // number of chunks per Y block, rounded up + num_chunks = (num_chunks + gridDim.y - 1) / gridDim.y; - outtype_strip val = 0; +#pragma unroll 1 + for(uint32_t n = 0; n < num_chunks; n++) { + // compute current chunk idx + uint32_t chunk_idx = blockIdx.y + n * gridDim.y; - // Zero out shared memory since it's used later to index into where we want - // 0-valued taps - for (int32_t i = threadIdx.x; i < filter_len + blockDim.x; i += blockDim.x) { - s_data[i] = 0.0; - } + // ensure s_data is consumed from last iteration of chunk loop + if( n > 0 ) + __syncthreads(); - __syncthreads(); + // load signal, pad extra elements with zeros + for (int32_t lidx = threadIdx.x, gidx = chunk_idx * CONV1D_ELEMENTS_PER_BLOCK - filter_len + 1 + threadIdx.x; + gidx < static_cast((chunk_idx+1) * CONV1D_ELEMENTS_PER_BLOCK) ; + gidx += THREADS, lidx += THREADS) { - if (threadIdx.x < filter_len) { - bdims[Rank - 1] = threadIdx.x; - detail::mapply([&](auto &&...args) { - s_filter[threadIdx.x] = d_filter.operator()(args...); - }, bdims); - } + // some elements may be out of range. We set their values to 0. + intype_strip val(0); - __syncthreads(); + if( gidx >= 0 && gidx < signal_len) { + bdims[Rank - 1] = gidx; + detail::mapply([&](auto &&...args) { + val = d_in.operator()(args...); + }, bdims); + } - if (chunk_idx == 0) { - // We want all blocks to process uniformly, so the first block's last few - // threads are idle to match what all other blocks do - s_data[threadIdx.x] = 0; + s_data[lidx] = val; + } + // wait for signal to load __syncthreads(); - // The first block just grabs all the data from the start of the sequence - if (threadIdx.x < signal_len && - (threadIdx.x < blockDim.x - filter_len + 1)) { - bdims[Rank - 1] = threadIdx.x; - detail::mapply([&](auto &&...args) { - s_data[threadIdx.x + filter_len - 1] = d_in.operator()(args...); - }, bdims); - } - } - else if (offset > 0 && offset < signal_len) { - // Each block processes blockDim.x-filt_len+1 samples, but needs to fetch - // all blockDim.x worth - bdims[Rank - 1] = offset; - detail::mapply([&](auto &&...args) { - s_data[threadIdx.x] = d_in.operator()(args...); - }, bdims); - } + // register array for output data + outtype_strip oval[EPT] = {0}; - __syncthreads(); + // Below will use pointer modification instead of offsets to change IMADS into IADS. IMADS go through FMA pipe. - // Even though all threads in the block fetched data, there is only enough - // data in shared memory for blockDim-filt_len+1 to operate on. The rest sit - // idle through this process. - if (tid < full_len && (threadIdx.x < blockDim.x - filter_len + 1)) { -#if 0 -#pragma unroll - for (index_t r = 0; r < filter_len; r++) { - val = val + s_filter[r] * s_data[threadIdx.x + filter_len - 1 - r]; - } -#else + // offset s_data to last element in the filter s_data += threadIdx.x + filter_len - 1; - for (int32_t r = 0; r < filter_len; r++) { -#if 0 - val = val + s_filter[0] * s_data[0]; -#else - val = detail::madd(s_filter[0], s_data[0], val); -#endif - s_data--; + + // for each tap + for(uint32_t f = 0; f < filter_len; f++) { + // load filter value into registers + ftype_strip fval = s_filter[0]; + + // next filter value s_filter++; - } -#endif - if (mode == MATX_C_MODE_FULL) { - bdims[Rank - 1] = tid; - detail::mapply([&](auto &&...args) { - d_out.operator()(args...) 
= val; - }, bdims); - } - else if (mode == MATX_C_MODE_SAME) { - int start_tid, stop_tid; - if (filter_len & 1) { - start_tid = (filter_len - 1) >> 1; + // register array for signal data + intype_strip ival[EPT]; + // load N elements of the signal into registers + +#pragma unroll + for(uint32_t i = 0; i < EPT; i++) { + ival[i] = s_data[i*THREADS]; } - else { - start_tid = (filter_len >> 1) - 1; + s_data--; // next signal value + + // compute N elements of the convolution +#pragma unroll + for(uint32_t i = 0; i < EPT; i++) { + oval[i] = detail::madd(ival[i], fval, oval[i]); } + } - stop_tid = full_len - (filter_len >> 1) - 1; + // restore shared pointers for next loop + s_filter -= filter_len; + s_data -= (threadIdx.x - 1); - if (tid >= start_tid && tid <= stop_tid) { - bdims[Rank - 1] = tid - start_tid; - detail::mapply([&](auto &&...args) { - d_out.operator()(args...) = val; - }, bdims); + // We have computed the full convolution here. we now need to output the correct range. + // compute output range + uint32_t start; + uint32_t stop; + + if (mode == MATX_C_MODE_FULL) { + start = 0; + stop = full_len - 1; + } else if ( mode == MATX_C_MODE_SAME) { + if( filter_len & 1) { + start = (filter_len - 1) / 2; + } else { + start = filter_len / 2 - 1; } + stop = full_len - filter_len / 2 - 1; + } else { + start = filter_len - 1; + stop = full_len - filter_len; } - else { // Valid - int start_tid, stop_tid; - start_tid = filter_len - 1; - stop_tid = full_len - filter_len; - if (tid >= start_tid && tid <= stop_tid) { - bdims[Rank - 1] = tid - start_tid; +#pragma unroll + for (uint32_t i = 0; i < EPT; i++) { + // index for the computation + uint32_t idx = chunk_idx * CONV1D_ELEMENTS_PER_BLOCK + i * THREADS + threadIdx.x; + // output index is shifted by start + int32_t gidx = idx - start; + + if(idx >= start && idx <= stop) { + bdims[Rank - 1] = gidx; detail::mapply([&](auto &&...args) { - d_out.operator()(args...) = val; - }, bdims); + d_out.operator()(args...) 
= oval[i]; + }, bdims); } } - } + } // end chunk loop } template diff --git a/include/matx/transforms/conv.h b/include/matx/transforms/conv.h index 8518468d..63af7ad2 100644 --- a/include/matx/transforms/conv.h +++ b/include/matx/transforms/conv.h @@ -54,33 +54,43 @@ inline void matxDirectConv1DInternal(OutputType &o, const InType &i, using strip_filter_t = typename FilterType::scalar_type; using shape_type = typename OutputType::shape_type; MATX_STATIC_ASSERT(OutputType::Rank() == InType::Rank(), matxInvalidDim); - MATX_ASSERT_STR(filter.Size(filter.Rank()-1) < BLOCK_SIZE_NON_RECURSIVE, matxInvalidSize, - "Convolutions are limited to filter lengths < 1024"); + MATX_ASSERT_STR(filter.Size(filter.Rank()-1) < CONV1D_ELEMENTS_PER_BLOCK, matxInvalidSize, + "Convolutions are limited to filter lengths < 1024"); + + MATX_ASSERT_STR(mode != MATX_C_MODE_FULL || o.Size(o.Rank()-1) == i.Size(i.Rank()-1) + filter.Size(filter.Rank()-1) - 1, + matxInvalidSize, "Output size for FULL convolution incorrect"); + MATX_ASSERT_STR(mode != MATX_C_MODE_SAME || o.Size(o.Rank()-1) == i.Size(i.Rank()-1), + matxInvalidSize, "Output size for SAME convolution incorrect"); #ifdef __CUDACC__ - // Scale the filter - size_t filter_shm; - if (sizeof(strip_filter_t) < sizeof(strip_input_t)) { - filter_shm = (filter.Size(filter.Rank()-1) * sizeof(strip_filter_t) + (sizeof(strip_input_t)-1)) / sizeof(strip_input_t) * sizeof(strip_input_t); - } - else { - filter_shm = filter.Size(filter.Rank()-1) * sizeof(strip_filter_t); - } + size_t filter_len = filter.Size(filter.Rank()-1); + size_t signal_len = i.Size(i.Rank()-1); - auto shmsize = filter_shm + sizeof(strip_input_t) * (filter.Size(filter.Rank()-1) + BLOCK_SIZE_NON_RECURSIVE); - + size_t filter_shm = sizeof(strip_filter_t) * filter_len; + size_t signal_shm = sizeof(strip_input_t) * (CONV1D_ELEMENTS_PER_BLOCK + filter_len); + + // align filter size to signal size + int align = std::alignment_of_v; + filter_shm = (filter_shm + align - 1) / align * align; + + size_t shmsize = filter_shm + signal_shm; shape_type sig_len = i.Size(OutputType::Rank() - 1); - float work_per_block = - static_cast(BLOCK_SIZE_NON_RECURSIVE - filter.Size(filter.Rank()-1) + 1); - int num_blocks = static_cast(std::ceil( - static_cast(sig_len + filter.Size(filter.Rank()-1) - 1) / work_per_block)); + int work_per_block = CONV1D_ELEMENTS_PER_BLOCK; + int num_blocks = (int)(sig_len + filter.Size(filter.Rank()-1) + work_per_block -1) / work_per_block; + + // number below was chosen arbitrarily. Cannot be more than 65536. 
+ num_blocks = std::min(num_blocks, 10000); int grid_size = static_cast(TotalSize(i)/i.Size(i.Rank() - 1)); dim3 gsize(grid_size, num_blocks); - Conv1D<<>>( - o, i, filter, sig_len, mode); + constexpr int EPT = 4; + constexpr int THREADS = CONV1D_ELEMENTS_PER_BLOCK / EPT; + static_assert(CONV1D_ELEMENTS_PER_BLOCK % EPT == 0); + + Conv1D<<>>( + o, i, filter, sig_len, mode); #endif } diff --git a/test/00_transform/ConvCorr.cu b/test/00_transform/ConvCorr.cu index 05165a8e..5584bfa1 100644 --- a/test/00_transform/ConvCorr.cu +++ b/test/00_transform/ConvCorr.cu @@ -56,6 +56,10 @@ constexpr index_t c_len1_valid_even = a_len1 - b_len1_even + 1; constexpr index_t c_len1_valid_odd = a_len1 - b_len1_odd + 1; constexpr index_t c_len1_same = a_len1; +constexpr index_t a_len = 8 * 1228800 + 2 * 32768; +constexpr index_t b_len = 209; +constexpr index_t c_len = a_len + b_len - 1; + template class CorrelationConvolutionTest : public ::testing::Test { protected: @@ -114,19 +118,63 @@ protected: float thresh = 0.01f; }; +template +class CorrelationConvolutionLargeTest : public ::testing::Test { +protected: + void SetUp() override + { + CheckTestTypeSupport(); + pb = std::make_unique(); + + // Half precision needs a bit more tolerance when compared to + // fp32 + if constexpr (is_complex_half_v || is_matx_half_v) { + thresh = 0.2f; + } + } + + void TearDown() { pb.reset(); } + + std::unique_ptr pb; + tensor_t av{{a_len}}; + tensor_t bv{{b_len}}; + tensor_t cv{{c_len}}; + float thresh = 0.01f; +}; + template class CorrelationConvolutionTestFloatTypes : public CorrelationConvolutionTest { }; +template +class CorrelationConvolutionLargeTestFloatTypes + : public CorrelationConvolutionLargeTest { +}; + template class CorrelationConvolution2DTestFloatTypes : public CorrelationConvolution2DTest { }; TYPED_TEST_SUITE(CorrelationConvolutionTestFloatTypes, MatXFloatTypes); +TYPED_TEST_SUITE(CorrelationConvolutionLargeTestFloatTypes, MatXFloatNonHalfTypes); TYPED_TEST_SUITE(CorrelationConvolution2DTestFloatTypes, MatXFloatTypes); +// Real/real direct 1D convolution Large +TYPED_TEST(CorrelationConvolutionLargeTestFloatTypes, Direct1DConvolutionLarge) +{ + MATX_ENTER_HANDLER(); + this->pb->template InitTVGenerator("00_transforms", "conv_operators", {a_len, b_len}); + this->pb->RunTVGenerator("conv"); + this->pb->NumpyToTensorView(this->av, "a_op"); + this->pb->NumpyToTensorView(this->bv, "b_op"); + conv1d(this->cv, this->av, this->bv, MATX_C_MODE_FULL, 0); + + MATX_TEST_ASSERT_COMPARE(this->pb, this->cv, "conv_full", this->thresh); + MATX_EXIT_HANDLER(); +} + // Real/real direct 1D convolution TYPED_TEST(CorrelationConvolutionTestFloatTypes, Direct1DConvolutionFullEven) { diff --git a/test/include/test_types.h b/test/include/test_types.h index 256e8410..b8731dff 100644 --- a/test/include/test_types.h +++ b/test/include/test_types.h @@ -82,6 +82,9 @@ typedef Types, cuda::std::complex, matx::matxFp16Complex, matx::matxBf16Complex> MatXFloatTypes; +typedef Types, cuda::std::complex> + MatXFloatNonHalfTypes; typedef Types MatXFloatNonComplexTypes; typedef Types MatXFloatHalfTypes;
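
---

Notes for reviewers (illustrative sketches, not part of the patch):

1. FFMA-only complex multiply-add. The float/double complex branch of madd() in include/matx/core/utils.h is reworked so every step of the complex multiply-accumulate can be contracted by nvcc into a single FFMA/DFMA rather than separate multiply and add instructions. A minimal standalone sketch of that pattern follows; the function name complex_madd and the exact accumulation order are assumptions for illustration, not the library code.

#include <cuda/std/complex>

// (xr + i*xi) * (yr + i*yi) + (zr + i*zi), accumulated one product at a time
// so that each line below maps to one fused multiply-add under nvcc's default
// FMA contraction (-fmad=true).
template <typename T>
__host__ __device__ inline cuda::std::complex<T>
complex_madd(const cuda::std::complex<T> &x,
             const cuda::std::complex<T> &y,
             const cuda::std::complex<T> &z)
{
  const T xr = x.real(), xi = x.imag();
  const T yr = y.real(), yi = y.imag();

  T re = z.real();
  re = re + xr * yr;   // FFMA/DFMA
  re = re - xi * yi;   // FFMA/DFMA with a negated operand
  T im = z.imag();
  im = im + xi * yr;   // FFMA/DFMA
  im = im + xr * yi;   // FFMA/DFMA
  return {re, im};
}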
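
2. Dynamic shared-memory sizing. The launch in include/matx/transforms/conv.h allocates one block-sized signal chunk plus a filter-length halo, preceded by the filter itself, and pads the filter region up to the alignment of the signal element type so the signal region starts on an aligned address. Below is a hedged host-side sketch of that arithmetic; the helper name conv1d_shm_bytes is hypothetical, and the constant mirrors CONV1D_ELEMENTS_PER_BLOCK from the patch.

#include <cstddef>
#include <type_traits>

constexpr std::size_t kConv1dElementsPerBlock = 512;  // mirrors CONV1D_ELEMENTS_PER_BLOCK

template <typename FilterT, typename SignalT>
constexpr std::size_t conv1d_shm_bytes(std::size_t filter_len)
{
  // Filter region, rounded up to the signal type's alignment so that the
  // signal region placed immediately after it is properly aligned.
  std::size_t filter_shm = sizeof(FilterT) * filter_len;
  constexpr std::size_t align = std::alignment_of_v<SignalT>;
  filter_shm = (filter_shm + align - 1) / align * align;

  // One chunk of the signal plus a filter_len halo on the leading edge.
  const std::size_t signal_shm = sizeof(SignalT) * (kConv1dElementsPerBlock + filter_len);

  return filter_shm + signal_shm;
}

// Example: a 105-tap float filter over a float signal needs
// conv1d_shm_bytes<float, float>(105) == 420 + 2468 == 2888 bytes per block.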
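
3. Output windowing. The kernel always computes the full convolution of length signal_len + filter_len - 1 and then writes back only the slice requested by the mode. The index arithmetic used for the three modes is pulled out here into a self-contained helper; the enum and struct names below are illustrative stand-ins for matxConvCorrMode_t, not library API.

#include <cstdint>

enum class ConvMode { Full, Same, Valid };  // stands in for matxConvCorrMode_t

struct ConvWindow {
  std::uint32_t start;  // first index of the full convolution that is kept
  std::uint32_t stop;   // last index that is kept (inclusive)
};

inline ConvWindow conv1d_window(std::uint32_t signal_len, std::uint32_t filter_len, ConvMode mode)
{
  const std::uint32_t full_len = signal_len + filter_len - 1;
  switch (mode) {
    case ConvMode::Full:    // full_len outputs
      return {0, full_len - 1};
    case ConvMode::Same: {  // signal_len outputs; even filters keep the extra tap on the left
      const std::uint32_t start = (filter_len & 1) ? (filter_len - 1) / 2
                                                   : filter_len / 2 - 1;
      return {start, full_len - filter_len / 2 - 1};
    }
    default:                // Valid: signal_len - filter_len + 1 outputs
      return {filter_len - 1, full_len - filter_len};
  }
}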