Skip to content

Commit

Permalink
ntt/{kernels.cu,kernels/*}: switch to shfl_bfly() method and clean up.
Browse files Browse the repository at this point in the history
  • Loading branch information
dot-asm committed Jan 2, 2024
1 parent e5add3f commit c332cc9
Show file tree
Hide file tree
Showing 5 changed files with 4 additions and 29 deletions.
25 changes: 0 additions & 25 deletions ntt/kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,6 @@

#include <cooperative_groups.h>

#ifdef __CUDA_ARCH__
// Butterfly-exchange a whole fr_t across the warp: every element of `r`
// is replaced by the corresponding element held by the lane whose id is
// this lane's id XOR `laneMask`.
//
// The participation mask names all 32 lanes, so every lane of the warp
// is expected to reach this call.
__device__ __forceinline__
void shfl_bfly(fr_t& r, int laneMask)
{
    const unsigned int all_lanes = 0xFFFFFFFF;

    // One shuffle per element; r.len() supplies the trip count.
    #pragma unroll
    for (int i = 0; i < r.len(); i++) {
        r[i] = __shfl_xor_sync(all_lanes, r[i], laneMask);
    }
}
#endif

// Butterfly-exchange a single index across the warp: `index` is
// overwritten with the value held by the lane whose id is this lane's
// id XOR `laneMask`.
//
// The participation mask names all 32 lanes, so every lane of the warp
// is expected to reach this call.
__device__ __forceinline__
void shfl_bfly(index_t& index, int laneMask)
{
    const unsigned int all_lanes = 0xFFFFFFFF;

    index = __shfl_xor_sync(all_lanes, index, laneMask);
}

// Exchange the values of `u1` and `u2` through one temporary copy.
// T must be copy-constructible and copy-assignable.
template<typename T>
__device__ __forceinline__
void swap(T& u1, T& u2)
{
    T held = u2;    // keep the second value while it is overwritten

    u2 = u1;
    u1 = held;
}

template<typename T>
__device__ __forceinline__
T bit_rev(T i, unsigned int nbits)
Expand Down
2 changes: 1 addition & 1 deletion ntt/kernels/ct_mixed_radix_narrow.cu
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ void _CT_NTT(const unsigned int radix, const unsigned int lg_domain_size,
for (int z = 0; z < z_count; z++) {
fr_t t = fr_t::csel(r[1][z], r[0][z], pos);

shfl_bfly(t, laneMask);
t.shfl_bfly(laneMask);

r[0][z] = fr_t::csel(t, r[0][z], !pos);
r[1][z] = fr_t::csel(t, r[1][z], pos);
Expand Down
2 changes: 1 addition & 1 deletion ntt/kernels/ct_mixed_radix_wide.cu
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ void _CT_NTT(const unsigned int radix, const unsigned int lg_domain_size,

#ifdef __CUDA_ARCH__
fr_t x = fr_t::csel(r1, r0, pos);
shfl_bfly(x, laneMask);
x.shfl_bfly(laneMask);
r0 = fr_t::csel(x, r0, !pos);
r1 = fr_t::csel(x, r1, pos);
#endif
Expand Down
2 changes: 1 addition & 1 deletion ntt/kernels/gs_mixed_radix_narrow.cu
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ void _GS_NTT(const unsigned int radix, const unsigned int lg_domain_size,
#ifdef __CUDA_ARCH__
t = fr_t::csel(r[1][z], r[0][z], pos);

shfl_bfly(t, laneMask);
t.shfl_bfly(laneMask);

r[0][z] = fr_t::csel(t, r[0][z], !pos);
r[1][z] = fr_t::csel(t, r[1][z], pos);
Expand Down
2 changes: 1 addition & 1 deletion ntt/kernels/gs_mixed_radix_wide.cu
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ void _GS_NTT(const unsigned int radix, const unsigned int lg_domain_size,
bool pos = rank < laneMask;
#ifdef __CUDA_ARCH__
t = fr_t::csel(r1, r0, pos);
shfl_bfly(t, laneMask);
t.shfl_bfly(laneMask);
r0 = fr_t::csel(t, r0, !pos);
r1 = fr_t::csel(t, r1, pos);
#endif
Expand Down

0 comments on commit c332cc9

Please sign in to comment.