PeixuanZuo committed Feb 1, 2023
1 parent 254b4a0 commit c557a36
Showing 1 changed file with 34 additions and 25 deletions.
59 changes: 34 additions & 25 deletions onnxruntime/core/providers/cuda/math/softmax_blockwise_impl.cuh
@@ -74,11 +74,25 @@ struct SumExpFloat {
const AccumT max_k;
};

+// One block has N (warps_per_block) warps; one warp has M (WARP_SIZE) threads.
+// 1. All threads in the block read data into shared memory.
+// 2. Reduce all data down to the first warp. Only the first N threads of warp-0 are used: thread-0 reduces the
+//    data in warp-0's region and writes the result into the slot of data0, thread-1 reduces the data in warp-1's
+//    region and writes the result into the slot of data1, and so on.
+//    __syncwarp(mask) is necessary here to make sure threads 1..N-1 delay writing into warp-0's region until
+//    thread-0 has finished reading from it.
+// Shared memory
+// -----------------------------------------------------------------------------------------------------------------------
+// | data0 | data1 | data2 | .... | dataM | ... | dataM*2 | ... |
+// -----------------------------------------------------------------------------------------------------------------------
+// |                                       |                                       |
+// -------------------warp-0----------------------------------warp-1----------------------------------warp-2--------------
+// TODO: ROCm doesn't support __syncwarp() yet; another mechanism is needed to guarantee reads finish before writes.
+// 3. Reduce all data to the first thread of warp-0.

template <template <typename> class Reduction, typename AccumT>
-__device__ __forceinline__ AccumT
-blockReduce(AccumT* smem, AccumT val,
-            const Reduction<AccumT>& r,
-            AccumT defaultVal) {
+__device__ __forceinline__ AccumT blockReduce(AccumT* smem, AccumT val,
+                                              const Reduction<AccumT>& r,
+                                              AccumT defaultVal) {
// To avoid RaW races from chaining blockReduce calls together, we need a sync here
__syncthreads();
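
For orientation, the three numbered steps in the comment above correspond to a fairly standard block-reduction shape. Below is a minimal self-contained sketch of the same pattern for a plain sum; the helper name, the fixed sum operation, and the assumption that blockDim.x is a multiple of 32 are illustrative, not taken from the ONNX Runtime source.

// Sketch only: blockSumSketch is an invented name, assuming a sum reduction
// and blockDim.x a multiple of 32 (so M = WARP_SIZE = 32, N = warps_per_block).
__device__ float blockSumSketch(float* smem, float val) {
  const int tid = threadIdx.x;
  const int lane = tid % 32;            // position within the warp
  const int n_warps = blockDim.x / 32;  // N = warps_per_block

  // Step 1: every thread stores its value into shared memory.
  smem[tid] = val;
  __syncthreads();

  // Step 2: the first N threads of warp-0 each reduce one warp's region.
  if (tid < 32 && lane < n_warps) {
    float warp_sum = 0.f;
    for (int i = 0; i < 32; ++i) {
      warp_sum += smem[lane * 32 + i];  // thread-`lane` reads warp-`lane`'s region
    }
    // Threads 1..N-1 are about to overwrite warp-0's region (smem[0..31]),
    // which thread-0 may still be reading, hence the warp-level barrier.
    unsigned mask = static_cast<unsigned>((1ull << n_warps) - 1);
    __syncwarp(mask);
    smem[lane] = warp_sum;              // write the partial into the slot of data<lane>
  }
  __syncthreads();

  // Step 3: thread-0 reduces the N per-warp partials to the final value.
  if (tid == 0) {
    float total = 0.f;
    for (int i = 0; i < n_warps; ++i) {
      total += smem[i];
    }
    smem[0] = total;
  }
  __syncthreads();
  return smem[0];
}

The softmax kernel chains two such reductions per row (the row max, then the sum of exponentials via SumExpFloat), which is why the real blockReduce opens with the __syncthreads() shown above: without it, the second call could overwrite shared memory that a slow thread is still reading from the first.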

@@ -125,12 +139,11 @@ blockReduce(AccumT* smem, AccumT val,
}

template <template <typename, typename> class Reduction, int ILP, typename T, typename AccumT>
-__device__ __forceinline__ AccumT
-ilpReduce(int shift,
-          T* data,
-          int size,
-          const Reduction<T, AccumT>& r,
-          AccumT defaultVal) {
+__device__ __forceinline__ AccumT ilpReduce(int shift,
+                                            T* data,
+                                            int size,
+                                            const Reduction<T, AccumT>& r,
+                                            AccumT defaultVal) {
using LoadT = aligned_vector<T, ILP>;
AccumT threadVal = defaultVal;
int offset = threadIdx.x;
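
The LoadT alias drives the vectorized half of the reduction: the aligned bulk of the row is read ILP elements per memory instruction, with a scalar loop for the tail. A condensed sketch of that access pattern follows; the aligned_vector definition is assumed to match the one in onnxruntime's CUDA headers, and ilpSumSketch is an invented name.

// Assumed to mirror onnxruntime's aligned_vector: ILP elements with an
// alignment that permits a single vector load/store.
template <typename T, int ILP>
struct alignas(sizeof(T) * ILP) aligned_vector {
  T val[ILP];
};

// Sketch of the ILP access pattern for a sum, assuming `data` is already
// aligned to sizeof(T) * ILP (the real ilpReduce uses `shift` to peel off
// misaligned leading elements first).
template <int ILP, typename T, typename AccumT>
__device__ AccumT ilpSumSketch(const T* data, int size) {
  using LoadT = aligned_vector<T, ILP>;
  AccumT threadVal = AccumT(0);

  // Vectorized bulk: each iteration fetches ILP contiguous elements at once.
  int last = size % (ILP * blockDim.x);
  for (int offset = threadIdx.x; offset * ILP < size - last; offset += blockDim.x) {
    LoadT vec = reinterpret_cast<const LoadT*>(data)[offset];
#pragma unroll
    for (int i = 0; i < ILP; ++i) {
      threadVal += static_cast<AccumT>(vec.val[i]);
    }
  }

  // Scalar tail: the remaining size % (ILP * blockDim.x) elements.
  for (int offset = size - last + threadIdx.x; offset < size; offset += blockDim.x) {
    threadVal += static_cast<AccumT>(data[offset]);
  }
  return threadVal;
}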
@@ -174,13 +187,11 @@ ilpReduce(int shift,
* This will apply the Epilogue with vectorized reads & writes when input & output have the same shift
*/
template <int ILP, typename scalar_t, typename accum_t, typename outscalar_t, template <typename, typename, typename> class Epilogue>
-__device__ __forceinline__ void
-WriteFpropResultsVectorized(
-    int size,
-    const int shift,
-    scalar_t* input,
-    outscalar_t* output,
-    Epilogue<scalar_t, accum_t, outscalar_t> epilogue) {
+__device__ __forceinline__ void WriteFpropResultsVectorized(int size,
+                                                            const int shift,
+                                                            scalar_t* input,
+                                                            outscalar_t* output,
+                                                            Epilogue<scalar_t, accum_t, outscalar_t> epilogue) {
using LoadT = aligned_vector<scalar_t, ILP>;
using StoreT = aligned_vector<outscalar_t, ILP>;
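
The shift referred to in the comment above is the pointer's element offset from the previous ILP-aligned boundary; the vectorized path is taken only when input and output agree on it, so both can peel the same number of leading scalar elements before switching to aligned vector accesses. A sketch of how such a shift is commonly derived (an assumed helper for illustration, not ORT code):

#include <cstdint>

// Sketch: element offset of `ptr` past the previous (sizeof(T) * ILP)-byte
// boundary. alignmentShiftSketch is an invented name.
template <typename T, int ILP>
__device__ int alignmentShiftSketch(const T* ptr) {
  constexpr int kAlignBytes = sizeof(T) * ILP;  // e.g. 16 bytes for T=float, ILP=4
  return static_cast<int>((reinterpret_cast<uintptr_t>(ptr) % kAlignBytes) / sizeof(T));
}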

@@ -232,12 +243,10 @@ WriteFpropResultsVectorized(
* This will apply the Epilogue with non-vectorized reads & writes for the general case
*/
template <int ILP, typename scalar_t, typename accum_t, typename outscalar_t, template <typename, typename, typename> class Epilogue>
-__device__ __forceinline__ void
-WriteFpropResults(
-    int classes,
-    scalar_t* input,
-    outscalar_t* output,
-    Epilogue<scalar_t, accum_t, outscalar_t> epilogue) {
+__device__ __forceinline__ void WriteFpropResults(int classes,
+                                                  scalar_t* input,
+                                                  outscalar_t* output,
+                                                  Epilogue<scalar_t, accum_t, outscalar_t> epilogue) {
int offset = threadIdx.x;

int last = classes % (ILP * blockDim.x);
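
Both writers apply an Epilogue functor to turn each loaded input element into an output element. For softmax forward, a functor of the expected shape would carry the row maximum and the sum of exponentials; the sketch below is illustrative, and ORT's actual epilogue is defined elsewhere in the provider.

#include <cmath>

// Illustrative epilogue of the shape Epilogue<scalar_t, accum_t, outscalar_t>:
// probability = exp(x - row_max) / sum(exp(x - row_max)), computed in accum_t.
// SoftMaxForwardEpilogueSketch is an invented name.
template <typename scalar_t, typename accum_t, typename outscalar_t>
struct SoftMaxForwardEpilogueSketch {
  __device__ SoftMaxForwardEpilogueSketch(accum_t max_input, accum_t sum)
      : max_input(max_input), sum(sum) {}

  __device__ outscalar_t operator()(scalar_t input) const {
    return static_cast<outscalar_t>(
        std::exp(static_cast<accum_t>(input) - max_input) / sum);
  }

  const accum_t max_input;
  const accum_t sum;
};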
@@ -264,8 +273,8 @@ WriteFpropResults(

template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t,
template <typename, typename, typename> class Epilogue>
-__global__ void
-softmax_block_forward(outscalar_t* output, scalar_t* input, int classes, int input_stride, int output_stride) {
+__global__ void softmax_block_forward(outscalar_t* output, scalar_t* input, int classes,
+                                      int input_stride, int output_stride) {
extern __shared__ unsigned char smem[];
auto sdata = reinterpret_cast<accscalar_t*>(smem);
