Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
luitjens committed Sep 2, 2022
1 parent 0b2a54d commit b7f8ace
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 11 deletions.
13 changes: 5 additions & 8 deletions include/matx/kernels/conv.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ __global__ void Conv1D(OutType d_out, InType d_in, FilterType d_filter,
// pad bytes to alignment of InType
int align = std::alignment_of_v<intype_strip>;
filter_bytes = (filter_bytes + align - 1) / align * align;

intype_strip *s_data = reinterpret_cast<intype_strip*>(&s_exch1d[filter_bytes]);

// load filter
Expand All @@ -87,22 +87,20 @@ __global__ void Conv1D(OutType d_out, InType d_in, FilterType d_filter,
for(uint32_t n = 0; n < num_chunks; n++) {
// compute current chunk idx
uint32_t chunk_idx = blockIdx.y + n * gridDim.y;

// ensure s_data is consumed from last iteration of chunk loop
if( n > 0 )
__syncthreads();

//printf("start: %d, end: %d\n", chunk_idx * CONV1D_ELEMENTS_PER_BLOCK - filter_len + 1 + threadIdx.x, (chunk_idx+1) * CONV1D_ELEMENTS_PER_BLOCK - filter_len + 1);
// load signal, pad extra elements with zeros
//TODO do we want the - filter_len + 1 here?
for (int32_t lidx = threadIdx.x, gidx = chunk_idx * CONV1D_ELEMENTS_PER_BLOCK - filter_len + 1 + threadIdx.x;
gidx < static_cast<int32_t>((chunk_idx+1) * CONV1D_ELEMENTS_PER_BLOCK) ;
gidx += THREADS, lidx += THREADS) {
gidx < static_cast<int32_t>((chunk_idx+1) * CONV1D_ELEMENTS_PER_BLOCK) ;
gidx += THREADS, lidx += THREADS) {

// some elements may be out of range. We set their values to 0.
intype_strip val(0);

if( gidx >= 0 && gidx < signal_len) { //TODO do I need upper check?
if( gidx >= 0 && gidx < signal_len) {
bdims[Rank - 1] = gidx;
detail::mapply([&](auto &&...args) {
val = d_in.operator()(args...);
Expand All @@ -123,7 +121,6 @@ __global__ void Conv1D(OutType d_out, InType d_in, FilterType d_filter,
// offset s_data to last element in the filter
s_data += threadIdx.x + filter_len - 1;

// TODO unroll for ILP?
// for each tap
for(uint32_t f = 0; f < filter_len; f++) {
// load filter value into registers
Expand Down
6 changes: 3 additions & 3 deletions include/matx/transforms/conv.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ inline void matxDirectConv1DInternal(OutputType &o, const InType &i,
"Convolutions are limited to filter lengths < 1024");

MATX_ASSERT_STR(mode != MATX_C_MODE_FULL || o.Size(o.Rank()-1) == i.Size(i.Rank()-1) + filter.Size(filter.Rank()-1) - 1,
matxInvalidSize, "Output size for FULL convolution incorrect");
matxInvalidSize, "Output size for FULL convolution incorrect");
MATX_ASSERT_STR(mode != MATX_C_MODE_SAME || o.Size(o.Rank()-1) == i.Size(i.Rank()-1),
matxInvalidSize, "Output size for SAME convolution incorrect");
matxInvalidSize, "Output size for SAME convolution incorrect");

#ifdef __CUDACC__
size_t filter_len = filter.Size(filter.Rank()-1);
Expand All @@ -72,7 +72,7 @@ inline void matxDirectConv1DInternal(OutputType &o, const InType &i,
// align filter size to signal size
int align = std::alignment_of_v<InType>;
filter_shm = (filter_shm + align - 1) / align * align;

size_t shmsize = filter_shm + signal_shm;

shape_type sig_len = i.Size(OutputType::Rank() - 1);
Expand Down

0 comments on commit b7f8ace

Please sign in to comment.