diff --git a/ntt/kernels.cu b/ntt/kernels.cu index 3d78821..da000ee 100644 --- a/ntt/kernels.cu +++ b/ntt/kernels.cu @@ -175,26 +175,26 @@ fr_t get_intermediate_root(index_t pow, const fr_t (*roots)[WINDOW_SIZE], } __launch_bounds__(1024) __global__ -void LDE_distribute_powers(fr_t* d_inout, uint32_t lg_blowup, bool bitrev, +void LDE_distribute_powers(fr_t* d_inout, uint32_t lg_domain_size, + uint32_t lg_blowup, bool bitrev, const fr_t (*gen_powers)[WINDOW_SIZE]) { - index_t idx = threadIdx.x + blockDim.x * (index_t)blockIdx.x; - index_t pow = idx; - fr_t r = d_inout[idx]; - - if (bitrev) { - size_t domain_size = gridDim.x * (size_t)blockDim.x; - assert((domain_size & (domain_size-1)) == 0); - uint32_t lg_domain_size = 63 - __clzll(domain_size); - - pow = bit_rev(idx, lg_domain_size); - } +#if 0 + assert(blockDim.x * gridDim.x == blockDim.x * (size_t)gridDim.x); +#endif + size_t domain_size = (size_t)1 << lg_domain_size; + index_t idx = threadIdx.x + blockDim.x * blockIdx.x; - pow <<= lg_blowup; + #pragma unroll 1 + for (; idx < domain_size; idx += blockDim.x * gridDim.x) { + fr_t r = d_inout[idx]; - r = r * get_intermediate_root(pow, gen_powers); + index_t pow = bitrev ? bit_rev(idx, lg_domain_size) : idx; + pow <<= lg_blowup; + r *= get_intermediate_root(pow, gen_powers); - d_inout[idx] = r; + d_inout[idx] = r; + } } __launch_bounds__(1024) __global__ diff --git a/ntt/ntt.cuh b/ntt/ntt.cuh index ca4fcd5..f68070e 100644 --- a/ntt/ntt.cuh +++ b/ntt/ntt.cuh @@ -57,22 +57,23 @@ protected: private: static void LDE_powers(fr_t* inout, bool innt, bool bitrev, - uint32_t lg_domain_size, uint32_t lg_blowup, + uint32_t lg_dsz, uint32_t lg_blowup, stream_t& stream) { - size_t domain_size = (size_t)1 << lg_domain_size; + size_t domain_size = (size_t)1 << lg_dsz; const auto gen_powers = NTTParameters::all(innt)[stream]->partial_group_gen_powers; if (domain_size < WARP_SZ) LDE_distribute_powers<<<1, domain_size, 0, stream>>> - (inout, lg_blowup, bitrev, gen_powers); - else if (domain_size < 512) + (inout, lg_dsz, lg_blowup, bitrev, gen_powers); + else if (lg_dsz < 32) LDE_distribute_powers<<>> - (inout, lg_blowup, bitrev, gen_powers); + (inout, lg_dsz, lg_blowup, bitrev, gen_powers); else - LDE_distribute_powers<<>> - (inout, lg_blowup, bitrev, gen_powers); + LDE_distribute_powers<<>> + (inout, lg_dsz, lg_blowup, bitrev, gen_powers); CUDA_OK(cudaGetLastError()); }