diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc
index 9ec7e849c0cb6..ccef01924b3ef 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc
@@ -242,8 +242,7 @@ Status Attention::Compute(OpKernelContext* context) const {
         // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough.
         // a math transform as below is leveraged to get a stable softmax:
         // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max))
-        // And for convenience, force max to 0.f if all xi are negative
-        float max = 0.f;
+        float max = -std::numeric_limits<float>::infinity();
         for (int i = 0; i < D; i++) {
           if (max < x[i]) max = x[i];
         }
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
index 5c72a1d53e316..26aee9affb46f 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -24,6 +24,7 @@ limitations under the License.
 #include
 #include
 #include
+#include <math_constants.h>
 #include "core/providers/cuda/cu_inc/common.cuh"
 #include "core/providers/cuda/cuda_common.h"
 #include "attention_impl.h"
@@ -61,22 +62,21 @@ __device__ inline void Softmax(const int ld, const int num_valid, const T* input
   __shared__ float sum_reverse_block;
   __shared__ float max_block;
 
-  float thread_data(0);
+  float thread_data_max(-CUDART_INF_F);
+  // e^x is represented as infinity if x is large enough, like 100.f.
+  // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough.
+  // a math transform as below is leveraged to get a stable softmax:
+  // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max))
 
   const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * ld;
   for (int i = threadIdx.x; i < num_valid; i += TPB) {
     const int index = offset + i;
-    if (thread_data < float(input[index])) {
-      thread_data = float(input[index]);
+    if (thread_data_max < float(input[index])) {
+      thread_data_max = float(input[index]);
     }
   }
 
-  // e^x is represented as infinity if x is large enough, like 100.f.
-  // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough.
-  // a math transform as below is leveraged to get a stable softmax:
-  // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max))
-  // And for convenience, force max to 0.f if all xi are negative
-  const auto max = BlockReduce(tmp_storage).Reduce(thread_data, cub::Max());
+  const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max());
 
   // Store max value
   if (threadIdx.x == 0) {
@@ -84,13 +84,14 @@ __device__ inline void Softmax(const int ld, const int num_valid, const T* input
   }
   __syncthreads();
 
+  float thread_data_sum(0.f);
   for (int i = threadIdx.x; i < num_valid; i += TPB) {
     const int index = offset + i;
     const float val = input[index];
-    thread_data += expf(val - max_block);
+    thread_data_sum += expf(val - max_block);
   }
 
-  const auto sum = BlockReduce(tmp_storage).Reduce(thread_data, cub::Sum());
+  const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_sum, cub::Sum());
   if (threadIdx.x == 0) {
     sum_reverse_block = 1.f / sum;
   }
@@ -114,17 +115,16 @@ __device__ inline void SoftmaxSmall(const int ld, const int num_valid, const T*
   const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * ld;
   const int index = offset + threadIdx.x;
 
-  float thread_data(0);
-  if (threadIdx.x < num_valid) {
-    thread_data = input[index];
-  }
-
   // e^x is represented as infinity if x is large enough, like 100.f.
   // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough.
   // a math transform as below is leveraged to get a stable softmax:
   // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max))
-  // And for convenience, force max to 0.f if all xi are negative
-  const auto max = BlockReduce(tmp_storage).Reduce(thread_data, cub::Max(), num_valid);
+  float thread_data_max(-CUDART_INF_F);
+  if (threadIdx.x < num_valid) {
+    thread_data_max = input[index];
+  }
+
+  const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max(), num_valid);
 
   // Store max value
   if (threadIdx.x == 0) {
@@ -132,12 +132,13 @@ __device__ inline void SoftmaxSmall(const int ld, const int num_valid, const T*
   }
   __syncthreads();
 
+  float thread_data_exp(0.f);
   if (threadIdx.x < num_valid) {
     const float val = input[index];
-    thread_data = expf(val - max_block);
+    thread_data_exp = expf(val - max_block);
   }
 
-  const auto sum = BlockReduce(tmp_storage).Reduce(thread_data, cub::Sum(), num_valid);
+  const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, cub::Sum(), num_valid);
 
   // Store max value
   if (threadIdx.x == 0) {
@@ -147,7 +148,7 @@ __device__ inline void SoftmaxSmall(const int ld, const int num_valid, const T*
 
   if (threadIdx.x < ld) {
     // this will be 0 for threadIdx.x >= num_valid
-    output[index] = T(thread_data * sum_reverse_block);
+    output[index] = T(thread_data_exp * sum_reverse_block);
   }
 }
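
Note (not part of the patch): the comments carried through both kernels describe the max-subtraction transform that the diff switches to. The sketch below is a minimal standalone C++ illustration of that transform; the function name StableSoftmax and the use of std::vector are illustrative assumptions, not code from the onnxruntime sources.

#include <cmath>
#include <limits>
#include <vector>

// Numerically stable softmax via the max-subtraction transform:
// e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max))
std::vector<float> StableSoftmax(const std::vector<float>& x) {
  // Start from -infinity (not 0.f) so the true maximum is found even when
  // every element is negative, mirroring the change made in attention.cc.
  float max = -std::numeric_limits<float>::infinity();
  for (float v : x) {
    if (max < v) max = v;
  }

  // Every exponent is <= 0, so expf never overflows, and the term for the
  // maximum element is exactly 1, so the denominator can never be zero.
  std::vector<float> y(x.size());
  float sum = 0.f;
  for (size_t i = 0; i < x.size(); ++i) {
    y[i] = std::exp(x[i] - max);
    sum += y[i];
  }

  const float sum_reverse = 1.f / sum;
  for (float& v : y) {
    v *= sum_reverse;
  }
  return y;
}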