Skip to content

Commit

Permalink
ntt/{kernels.cu,ntt.cuh}: ensure LDE_powers works even for large doma…
Browse files Browse the repository at this point in the history
…in sizes.
  • Loading branch information
dot-asm committed Dec 8, 2023
1 parent 62045e7 commit a6799d9
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 22 deletions.
30 changes: 15 additions & 15 deletions ntt/kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -175,26 +175,26 @@ fr_t get_intermediate_root(index_t pow, const fr_t (*roots)[WINDOW_SIZE],
}

__launch_bounds__(1024) __global__
void LDE_distribute_powers(fr_t* d_inout, uint32_t lg_blowup, bool bitrev,
void LDE_distribute_powers(fr_t* d_inout, uint32_t lg_domain_size,
uint32_t lg_blowup, bool bitrev,
const fr_t (*gen_powers)[WINDOW_SIZE])
{
index_t idx = threadIdx.x + blockDim.x * (index_t)blockIdx.x;
index_t pow = idx;
fr_t r = d_inout[idx];

if (bitrev) {
size_t domain_size = gridDim.x * (size_t)blockDim.x;
assert((domain_size & (domain_size-1)) == 0);
uint32_t lg_domain_size = 63 - __clzll(domain_size);

pow = bit_rev(idx, lg_domain_size);
}
#if 0
assert(blockDim.x * gridDim.x == blockDim.x * (size_t)gridDim.x);
#endif
size_t domain_size = (size_t)1 << lg_domain_size;
index_t idx = threadIdx.x + blockDim.x * blockIdx.x;

pow <<= lg_blowup;
#pragma unroll 1
for (; idx < domain_size; idx += blockDim.x * gridDim.x) {
fr_t r = d_inout[idx];

r = r * get_intermediate_root(pow, gen_powers);
index_t pow = bitrev ? bit_rev(idx, lg_domain_size) : idx;
pow <<= lg_blowup;
r *= get_intermediate_root(pow, gen_powers);

d_inout[idx] = r;
d_inout[idx] = r;
}
}

__launch_bounds__(1024) __global__
Expand Down
15 changes: 8 additions & 7 deletions ntt/ntt.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -57,22 +57,23 @@ protected:

private:
static void LDE_powers(fr_t* inout, bool innt, bool bitrev,
uint32_t lg_domain_size, uint32_t lg_blowup,
uint32_t lg_dsz, uint32_t lg_blowup,
stream_t& stream)
{
size_t domain_size = (size_t)1 << lg_domain_size;
size_t domain_size = (size_t)1 << lg_dsz;
const auto gen_powers =
NTTParameters::all(innt)[stream]->partial_group_gen_powers;

if (domain_size < WARP_SZ)
LDE_distribute_powers<<<1, domain_size, 0, stream>>>
(inout, lg_blowup, bitrev, gen_powers);
else if (domain_size < 512)
(inout, lg_dsz, lg_blowup, bitrev, gen_powers);
else if (lg_dsz < 32)
LDE_distribute_powers<<<domain_size / WARP_SZ, WARP_SZ, 0, stream>>>
(inout, lg_blowup, bitrev, gen_powers);
(inout, lg_dsz, lg_blowup, bitrev, gen_powers);
else
LDE_distribute_powers<<<domain_size / 512, 512, 0, stream>>>
(inout, lg_blowup, bitrev, gen_powers);
LDE_distribute_powers<<<gpu_props(stream).multiProcessorCount, 1024,
0, stream>>>
(inout, lg_dsz, lg_blowup, bitrev, gen_powers);

CUDA_OK(cudaGetLastError());
}
Expand Down

0 comments on commit a6799d9

Please sign in to comment.