conv1d optimizations. (#259)
Fixed complex math for float/double to generate only FFMAs/DFFMAs.
Added ILP
Fixed alignment issue
Made robust to large signals
Added unit tests and benchmark
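
The first item refers to writing the complex multiply-add out in real/imaginary components so the compiler can contract every step into a fused multiply-add. A minimal sketch of the idea, not the committed code (matx uses its own madd helper in include/matx/core/utils.h; cmadd below is illustrative):

#include <cuda/std/complex>

template <typename T> // T = float or double
__device__ cuda::std::complex<T> cmadd(const cuda::std::complex<T> &x,
                                       const cuda::std::complex<T> &y,
                                       const cuda::std::complex<T> &z) {
  T r = z.real(), i = z.imag();
  r += x.real() * y.real(); // each accumulate contracts to FFMA/DFMA with the default -fmad=true
  r -= x.imag() * y.imag();
  i += x.real() * y.imag();
  i += x.imag() * y.real();
  return {r, i};
}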

Co-authored-by: jluitjens <[email protected]>
luitjens and luitjens authored Sep 2, 2022
1 parent f021d5c commit 959b1a3
Showing 6 changed files with 231 additions and 139 deletions.
17 changes: 17 additions & 0 deletions bench/00_transform/conv.cu
@@ -46,3 +46,20 @@ void conv1d_2d_batch(nvbench::state &state,
[&out, &at, &bt](nvbench::launch &launch) { conv1d(out, at, bt, MATX_C_MODE_FULL, launch.get_stream()); });
}
NVBENCH_BENCH_TYPES(conv1d_2d_batch, NVBENCH_TYPE_AXES(conv_types));

template <typename ValueType>
void conv1d_large(nvbench::state &state,
nvbench::type_list<ValueType>)
{
auto out = make_tensor<ValueType>({39321704});
auto at = make_tensor<ValueType>({39321600});
auto bt = make_tensor<ValueType>({105});

out.PrefetchDevice(0);
at.PrefetchDevice(0);
bt.PrefetchDevice(0);

state.exec(
[&out, &at, &bt](nvbench::launch &launch) { conv1d(out, at, bt, MATX_C_MODE_FULL, launch.get_stream()); });
}
NVBENCH_BENCH_TYPES(conv1d_large, NVBENCH_TYPE_AXES(conv_types));
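
A note on the sizes in conv1d_large: a full-mode convolution produces signal_len + filter_len - 1 outputs, hence the 39321600 + 105 - 1 = 39321704 element output tensor. At 512 outputs per block (CONV1D_ELEMENTS_PER_BLOCK in conv.cuh below) this signal spans 76801 chunks, more than gridDim.y can cover in a single pass, so this benchmark plausibly exercises the new large-signal handling.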
25 changes: 23 additions & 2 deletions include/matx/core/utils.h
@@ -65,7 +65,7 @@ template <typename T1, typename T2, typename T3>
__MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ auto madd( const T1 &x, const T2 &y, const T3 &z) {
using T4 = decltype(x*y+z);
if constexpr (is_complex_v<T4> && !is_complex_half_v<T4>) {

using value_type = typename T4::value_type;

value_type xr, xi;
@@ -109,13 +109,34 @@ __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ auto madd( const T1 &x, const T2 &
//__half2 X = make_half2(x.real(), x.imag());
//__half2 Y = make_half2(y.real(), y.imag());
//__half2 Z = make_half2(z.real(), z.imag());

const __half2 &X = *reinterpret_cast<const __half2*>(&x);
const __half2 &Y = *reinterpret_cast<const __half2*>(&y);
const __half2 &Z = *reinterpret_cast<const __half2*>(&z);

#if 1
auto v = __hcmadd(X,Y,Z);
return T4(v.x, v.y);
#else
// In theory this could be faster, but the compiler is not folding the broadcast/swap into HFMAs
__half2 ari = make_half2(X.x, X.y);
// negate and swap supported in hardware sm_8.6+
__half2 air = make_half2(X.y, __hneg(X.x));
// broadcast supported in hardware
__half2 brr = make_half2(Y.x, Y.x);
// broadcast supported in hardware
__half2 bii = make_half2(Y.y, Y.y);
__half2 c = Z;
__half2 d;
// HFMA2 RD, RA.H1_H0, RB.H1_H1, RC.H1_H0
d = __hfma2(ari, brr, c);
// HFMA2 RD, RB.H0_H0, -RA.H0_NH1, RC.H1_H0
d = __hfma2(bii, -air, d);
return T4(d.x, d.y);
#endif
} else {
return x*y+z;
}
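
For the half-precision path above, the reinterpret_casts rely on the complex type having the same layout __hcmadd expects from a __half2: real in .x, imaginary in .y. A standalone sketch of that fast path (an illustrative kernel, not part of this commit):

#include <cuda_fp16.h>

__global__ void complex_madd_half(__half2 *out, const __half2 *a,
                                  const __half2 *b, const __half2 *c, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // __hcmadd treats each __half2 as one complex number and
    // computes a[i]*b[i] + c[i] in complex arithmetic.
    out[i] = __hcmadd(a[i], b[i], c[i]);
  }
}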
231 changes: 112 additions & 119 deletions include/matx/kernels/conv.cuh
@@ -13,7 +13,7 @@
#include "matx/core/type_utils.h"
#include "matx/core/tensor_utils.h"

#define BLOCK_SIZE_NON_RECURSIVE 1024
#define CONV1D_ELEMENTS_PER_BLOCK 512

namespace matx {

@@ -30,168 +30,161 @@ typedef enum {
} matxConvCorrMethod_t;

#ifdef __CUDACC__
template <typename OutType, typename InType, typename FilterType>
__launch_bounds__(1024)
template <int THREADS, int EPT, typename OutType, typename InType, typename FilterType>
__launch_bounds__(THREADS)
__global__ void Conv1D(OutType d_out, InType d_in, FilterType d_filter,
index_t signal_len,
matxConvCorrMode_t mode)
{

/* strategy:
Each thread computes EPT outputs, so each block produces EPT * THREADS outputs.
The full convolution is computed and the results are windowed down to the requested mode.
The filter is loaded into shared memory in full.
A chunk of the signal is loaded into shared memory with filter_len - 1 elements of padding on the negative side.
Elements that fall out of range are padded with zeros.
*/
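/* Worked example with illustrative launch parameters (THREADS and EPT are
   template arguments chosen by the launcher, not fixed here): THREADS = 128
   and EPT = 4 give 512 outputs per block, matching CONV1D_ELEMENTS_PER_BLOCK.
   For a 105-tap filter, each block then stages the 105 filter taps plus
   512 + 104 signal samples (chunk plus negative-side halo) in shared memory. */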
static_assert(InType::Rank() == FilterType::Rank());

const int Rank = InType::Rank();

extern __shared__ float s_exch[]; // Filter + halo
extern __shared__ char s_exch1d[]; // Filter + halo
using ftype_strip = typename FilterType::scalar_type;
using intype_strip = typename InType::scalar_type;
using outtype_strip = typename OutType::scalar_type;
int chunk_idx = blockIdx.y;
int batch_idx = blockIdx.x;
int32_t filter_len = d_filter.Size(Rank-1);
uint32_t filter_len = d_filter.Size(Rank-1);
uint32_t full_len = signal_len + filter_len - 1;

// All but the last dim will be populated
auto bdims = BlockToIdx(d_in, batch_idx, 1);

// Adjustment to keep base shm size as float, but if the filter is complex we
// need to adjust it
constexpr float filt_size_adj =
static_cast<float>(sizeof(ftype_strip)) / sizeof(s_exch[0]);
ftype_strip *s_filter = reinterpret_cast<ftype_strip *>(&s_exch1d[0]);

ftype_strip *s_filter = reinterpret_cast<ftype_strip *>(&s_exch[0]);
intype_strip *s_data;

// If the data type has a higher alignment type than the filter, we need to
// adjust our shm pointer
if constexpr (std::alignment_of_v < intype_strip >>
std::alignment_of_v<ftype_strip>) {
s_data =
matx::detail::AlignAddr<intype_strip>((uint8_t *)&s_exch[static_cast<int32_t>(
filter_len * filt_size_adj)]); // Start data portion after 2x the
// filter to remove conditionals and
// multiply by 0
}
else {
s_data = reinterpret_cast<intype_strip *>(&s_exch[static_cast<int32_t>(
filter_len *
filt_size_adj)]); // Start data portion after 2x the filter to
// remove conditionals and multiply by 0
size_t filter_bytes = filter_len * sizeof(ftype_strip);
// pad filter_bytes up to the alignment of InType
int align = std::alignment_of_v<intype_strip>;
filter_bytes = (filter_bytes + align - 1) / align * align;

intype_strip *s_data = reinterpret_cast<intype_strip*>(&s_exch1d[filter_bytes]);
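// Layout example (illustrative sizes): a 105-tap float filter occupies 420
// bytes; if intype_strip requires 8-byte alignment (e.g. a two-float complex
// type), filter_bytes rounds up to 424 and s_data begins at s_exch1d + 424.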

// load filter
for (uint32_t idx = threadIdx.x; idx < filter_len; idx += THREADS) {
bdims[Rank - 1] = idx;
detail::mapply([&](auto &&...args) {
s_filter[idx] = d_filter.operator()(args...);
}, bdims);
}

index_t full_len = signal_len + filter_len - 1;
// number of chunks in the signal, number of output elements / chunk size rounded up
uint32_t num_chunks = (signal_len + filter_len -1 + CONV1D_ELEMENTS_PER_BLOCK - 1) / CONV1D_ELEMENTS_PER_BLOCK;

// This is the location that is written in memory. Note that there will be
// duplicate tids based on this formula, but not all threads write out to
// memory. Some are only there to fetch data, while others both fetch and
// compute output
const int32_t tid =
static_cast<index_t>(chunk_idx) * (blockDim.x - filter_len + 1) +
threadIdx.x;
int offset = tid - filter_len + 1;
// number of chunks per Y block, rounded up
num_chunks = (num_chunks + gridDim.y - 1) / gridDim.y;
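// Worked example with the sizes from the new benchmark: signal_len = 39321600
// and filter_len = 105 give full_len = 39321704, so there are
// ceil(39321704 / 512) = 76801 chunks in total. With an illustrative
// gridDim.y of 65535, each Y block then walks ceil(76801 / 65535) = 2 chunks.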

outtype_strip val = 0;
#pragma unroll 1
for(uint32_t n = 0; n < num_chunks; n++) {
// compute current chunk idx
uint32_t chunk_idx = blockIdx.y + n * gridDim.y;

// Zero out shared memory since it's used later to index into where we want
// 0-valued taps
for (int32_t i = threadIdx.x; i < filter_len + blockDim.x; i += blockDim.x) {
s_data[i] = 0.0;
}
// ensure s_data from the previous chunk iteration has been fully consumed
if( n > 0 )
__syncthreads();

__syncthreads();
// load signal, pad extra elements with zeros
for (int32_t lidx = threadIdx.x, gidx = chunk_idx * CONV1D_ELEMENTS_PER_BLOCK - filter_len + 1 + threadIdx.x;
gidx < static_cast<int32_t>((chunk_idx+1) * CONV1D_ELEMENTS_PER_BLOCK) ;
gidx += THREADS, lidx += THREADS) {

if (threadIdx.x < filter_len) {
bdims[Rank - 1] = threadIdx.x;
detail::mapply([&](auto &&...args) {
s_filter[threadIdx.x] = d_filter.operator()(args...);
}, bdims);
}
// some elements may be out of range. We set their values to 0.
intype_strip val(0);

__syncthreads();
if( gidx >= 0 && gidx < signal_len) {
bdims[Rank - 1] = gidx;
detail::mapply([&](auto &&...args) {
val = d_in.operator()(args...);
}, bdims);
}

if (chunk_idx == 0) {
// We want all blocks to process uniformly, so the first block's last few
// threads are idle to match what all other blocks do
s_data[threadIdx.x] = 0;
s_data[lidx] = val;
}

// wait for signal to load
__syncthreads();

// The first block just grabs all the data from the start of the sequence
if (threadIdx.x < signal_len &&
(threadIdx.x < blockDim.x - filter_len + 1)) {
bdims[Rank - 1] = threadIdx.x;
detail::mapply([&](auto &&...args) {
s_data[threadIdx.x + filter_len - 1] = d_in.operator()(args...);
}, bdims);
}
}
else if (offset > 0 && offset < signal_len) {
// Each block processes blockDim.x-filt_len+1 samples, but needs to fetch
// all blockDim.x worth
bdims[Rank - 1] = offset;
detail::mapply([&](auto &&...args) {
s_data[threadIdx.x] = d_in.operator()(args...);
}, bdims);
}
// register array for output data
outtype_strip oval[EPT] = {0};

__syncthreads();
// Below we use pointer increments instead of indexed offsets so the address math becomes IADDs rather than IMADs; IMADs execute on the FMA pipe and would compete with the convolution math.
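// e.g. an indexed access like s_data[i*THREADS] costs an IMAD per access to
// form base + i*THREADS*sizeof(T); bumping s_filter++/s_data-- instead turns
// the per-iteration address math into plain IADDs.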

// Even though all threads in the block fetched data, there is only enough
// data in shared memory for blockDim-filt_len+1 to operate on. The rest sit
// idle through this process.
if (tid < full_len && (threadIdx.x < blockDim.x - filter_len + 1)) {
#if 0
#pragma unroll
for (index_t r = 0; r < filter_len; r++) {
val = val + s_filter[r] * s_data[threadIdx.x + filter_len - 1 - r];
}
#else
// offset s_data to last element in the filter
s_data += threadIdx.x + filter_len - 1;
for (int32_t r = 0; r < filter_len; r++) {
#if 0
val = val + s_filter[0] * s_data[0];
#else
val = detail::madd(s_filter[0], s_data[0], val);
#endif
s_data--;

// for each tap
for(uint32_t f = 0; f < filter_len; f++) {
// load filter value into registers
ftype_strip fval = s_filter[0];

// next filter value
s_filter++;
}
#endif

if (mode == MATX_C_MODE_FULL) {
bdims[Rank - 1] = tid;
detail::mapply([&](auto &&...args) {
d_out.operator()(args...) = val;
}, bdims);
}
else if (mode == MATX_C_MODE_SAME) {
int start_tid, stop_tid;
if (filter_len & 1) {
start_tid = (filter_len - 1) >> 1;
// register array for signal data
intype_strip ival[EPT];
// load N elements of the signal into registers

#pragma unroll
for(uint32_t i = 0; i < EPT; i++) {
ival[i] = s_data[i*THREADS];
}
else {
start_tid = (filter_len >> 1) - 1;
s_data--; // next signal value

// compute N elements of the convolution
#pragma unroll
for(uint32_t i = 0; i < EPT; i++) {
oval[i] = detail::madd(ival[i], fval, oval[i]);
}
}

stop_tid = full_len - (filter_len >> 1) - 1;
// restore shared pointers for next loop
s_filter -= filter_len;
s_data -= (threadIdx.x - 1);

if (tid >= start_tid && tid <= stop_tid) {
bdims[Rank - 1] = tid - start_tid;
detail::mapply([&](auto &&...args) {
d_out.operator()(args...) = val;
}, bdims);
// We have computed the full convolution here; now output only the requested range.
// compute output range
uint32_t start;
uint32_t stop;

if (mode == MATX_C_MODE_FULL) {
start = 0;
stop = full_len - 1;
} else if ( mode == MATX_C_MODE_SAME) {
if( filter_len & 1) {
start = (filter_len - 1) / 2;
} else {
start = filter_len / 2 - 1;
}
stop = full_len - filter_len / 2 - 1;
} else {
start = filter_len - 1;
stop = full_len - filter_len;
}
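// Worked example (illustrative): signal_len = 8, filter_len = 3, full_len = 10:
//   FULL:  start = 0, stop = 9 -> 10 outputs
//   SAME:  start = 1, stop = 8 ->  8 outputs (= signal_len)
//   VALID: start = 2, stop = 7 ->  6 outputs (= signal_len - filter_len + 1)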
else { // Valid
int start_tid, stop_tid;
start_tid = filter_len - 1;
stop_tid = full_len - filter_len;

if (tid >= start_tid && tid <= stop_tid) {
bdims[Rank - 1] = tid - start_tid;
#pragma unroll
for (uint32_t i = 0; i < EPT; i++) {
// index for the computation
uint32_t idx = chunk_idx * CONV1D_ELEMENTS_PER_BLOCK + i * THREADS + threadIdx.x;
// output index is shifted by start
int32_t gidx = idx - start;

if(idx >= start && idx <= stop) {
bdims[Rank - 1] = gidx;
detail::mapply([&](auto &&...args) {
d_out.operator()(args...) = val;
}, bdims);
d_out.operator()(args...) = oval[i];
}, bdims);
}
}
}
} // end chunk loop
}

template <typename OutType, typename InType, typename FilterType>
