Commit

Initialize max of softmax with lowest of float (#2786)
yufenglee authored Jan 9, 2020
1 parent 2c8179b commit 71b5165
Showing 2 changed files with 23 additions and 23 deletions.
onnxruntime/contrib_ops/cpu/bert/attention.cc (3 changes: 1 addition & 2 deletions)
@@ -242,8 +242,7 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
   // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough.
   // a math transform as below is leveraged to get a stable softmax:
   // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max))
-  // And for convenience, force max to 0.f if all xi are negative
-  float max = 0.f;
+  float max = -std::numeric_limits<float>::infinity();
   for (int i = 0; i < D; i++) {
     if (max < x[i]) max = x[i];
   }
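The CPU change above matters when every score in a row is strongly negative (for example, masked attention positions that receive a large negative bias): with the old 0.f seed, expf(x[i] - 0.f) underflows to 0 for every element, the sum becomes 0, and the normalization produces NaN. Below is a minimal standalone sketch of the stable softmax with the new seeding; it is not the ONNX Runtime code, and the StableSoftmax function and test values are illustrative only.

#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

// Numerically stable softmax: subtract the row maximum before exponentiating.
void StableSoftmax(const float* x, float* y, int n) {
  // Seed with the lowest possible value so an all-negative row still yields its true maximum.
  float max = -std::numeric_limits<float>::infinity();
  for (int i = 0; i < n; i++) {
    if (max < x[i]) max = x[i];
  }
  float sum = 0.f;
  for (int i = 0; i < n; i++) {
    y[i] = std::exp(x[i] - max);  // the largest element contributes exp(0) == 1
    sum += y[i];
  }
  for (int i = 0; i < n; i++) {
    y[i] /= sum;  // sum >= 1, so this never divides by zero
  }
}

int main() {
  // Large negative scores, e.g. masked attention logits. With a 0.f seed for max,
  // every exp(x[i] - 0.f) would underflow to 0 and the normalization would yield NaN.
  std::vector<float> x = {-10000.f, -10001.f, -10002.f};
  std::vector<float> y(x.size());
  StableSoftmax(x.data(), y.data(), static_cast<int>(x.size()));
  for (float v : y) std::printf("%f\n", v);  // finite probabilities that sum to 1
  return 0;
}

Seeding with -infinity (or std::numeric_limits<float>::lowest(), as the commit title suggests) guarantees the subtracted value is the true row maximum, so the largest element maps to exp(0) == 1 and the denominator is never zero.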
onnxruntime/contrib_ops/cuda/bert/attention_impl.cu (43 changes: 22 additions & 21 deletions)
@@ -24,6 +24,7 @@ limitations under the License.
 #include <cub/cub.cuh>
 #include <cublas_v2.h>
 #include <cuda_fp16.h>
+#include <math_constants.h>
 #include "core/providers/cuda/cu_inc/common.cuh"
 #include "core/providers/cuda/cuda_common.h"
 #include "attention_impl.h"
@@ -61,36 +62,36 @@ __device__ inline void Softmax(const int ld, const int num_valid, const T* input
   __shared__ float sum_reverse_block;
   __shared__ float max_block;

-  float thread_data(0);
+  float thread_data_max(-CUDART_INF_F);

+  // e^x is represented as infinity if x is large enough, like 100.f.
+  // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough.
+  // a math transform as below is leveraged to get a stable softmax:
+  // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max))
   const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * ld;
   for (int i = threadIdx.x; i < num_valid; i += TPB) {
     const int index = offset + i;
-    if (thread_data < float(input[index])) {
-      thread_data = float(input[index]);
+    if (thread_data_max < float(input[index])) {
+      thread_data_max = float(input[index]);
     }
   }

-  // e^x is represented as infinity if x is large enough, like 100.f.
-  // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough.
-  // a math transform as below is leveraged to get a stable softmax:
-  // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max))
-  // And for convenience, force max to 0.f if all xi are negative
-  const auto max = BlockReduce(tmp_storage).Reduce(thread_data, cub::Max());
+  const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max());

   // Store max value
   if (threadIdx.x == 0) {
     max_block = max;
   }
   __syncthreads();

+  float thread_data_sum(0.f);
   for (int i = threadIdx.x; i < num_valid; i += TPB) {
     const int index = offset + i;
     const float val = input[index];
-    thread_data += expf(val - max_block);
+    thread_data_sum += expf(val - max_block);
   }

-  const auto sum = BlockReduce(tmp_storage).Reduce(thread_data, cub::Sum());
+  const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_sum, cub::Sum());
   if (threadIdx.x == 0) {
     sum_reverse_block = 1.f / sum;
   }
@@ -114,30 +115,30 @@ __device__ inline void SoftmaxSmall(const int ld, const int num_valid, const T*
   const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * ld;
   const int index = offset + threadIdx.x;

-  float thread_data(0);
-  if (threadIdx.x < num_valid) {
-    thread_data = input[index];
-  }
-
   // e^x is represented as infinity if x is large enough, like 100.f.
   // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough.
   // a math transform as below is leveraged to get a stable softmax:
   // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max))
-  // And for convenience, force max to 0.f if all xi are negative
-  const auto max = BlockReduce(tmp_storage).Reduce(thread_data, cub::Max(), num_valid);
+  float thread_data_max(-CUDART_INF_F);
+  if (threadIdx.x < num_valid) {
+    thread_data_max = input[index];
+  }
+
+  const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max(), num_valid);

   // Store max value
   if (threadIdx.x == 0) {
     max_block = max;
   }
   __syncthreads();

+  float thread_data_exp(0.f);
   if (threadIdx.x < num_valid) {
     const float val = input[index];
-    thread_data = expf(val - max_block);
+    thread_data_exp = expf(val - max_block);
   }

-  const auto sum = BlockReduce(tmp_storage).Reduce(thread_data, cub::Sum(), num_valid);
+  const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, cub::Sum(), num_valid);

   // Store max value
   if (threadIdx.x == 0) {
@@ -147,7 +148,7 @@ __device__ inline void SoftmaxSmall(const int ld, const int num_valid, const T*

   if (threadIdx.x < ld) {
     // this will be 0 for threadIdx.x >= num_valid
-    output[index] = T(thread_data * sum_reverse_block);
+    output[index] = T(thread_data_exp * sum_reverse_block);
   }
 }

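To make the kernel-side pattern concrete, here is a self-contained CUDA sketch of a single-block softmax that follows the same two-pass structure as the patched Softmax kernel: a block-wide max reduction seeded with -CUDART_INF_F, then a block-wide sum of expf(x - max), using separate thread-local accumulators for the two passes. This is an illustration rather than the ONNX Runtime kernel; the BlockSoftmax name, the TPB = 32 launch, and the test values are made up for the example.

#include <cstdio>
#include <cub/cub.cuh>
#include <cuda_runtime.h>
#include <math_constants.h>

// Single-block softmax over n elements, mirroring the two-pass structure of the patched kernel.
template <int TPB>
__global__ void BlockSoftmax(const float* input, float* output, int n) {
  using BlockReduce = cub::BlockReduce<float, TPB>;
  __shared__ typename BlockReduce::TempStorage tmp_storage;
  __shared__ float max_block;
  __shared__ float sum_reverse_block;

  // Pass 1: block-wide maximum, seeded with -infinity so all-negative rows are handled.
  float thread_max = -CUDART_INF_F;
  for (int i = threadIdx.x; i < n; i += TPB) {
    thread_max = fmaxf(thread_max, input[i]);
  }
  const float max = BlockReduce(tmp_storage).Reduce(thread_max, cub::Max());
  if (threadIdx.x == 0) max_block = max;
  __syncthreads();  // also required before tmp_storage is reused below

  // Pass 2: block-wide sum of exp(x - max), accumulated in a separate variable.
  float thread_sum = 0.f;
  for (int i = threadIdx.x; i < n; i += TPB) {
    thread_sum += expf(input[i] - max_block);
  }
  const float sum = BlockReduce(tmp_storage).Reduce(thread_sum, cub::Sum());
  if (threadIdx.x == 0) sum_reverse_block = 1.f / sum;
  __syncthreads();

  // Pass 3: normalize.
  for (int i = threadIdx.x; i < n; i += TPB) {
    output[i] = expf(input[i] - max_block) * sum_reverse_block;
  }
}

int main() {
  const int n = 8;
  const float h_in[n] = {-10000.f, -10001.f, -10002.f, -10003.f,
                         -10000.f, -10001.f, -10002.f, -10003.f};
  float h_out[n];
  float *d_in, *d_out;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, n * sizeof(float));
  cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);
  BlockSoftmax<32><<<1, 32>>>(d_in, d_out, n);
  cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i) std::printf("%f\n", h_out[i]);  // finite values that sum to 1
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}

As in the patched kernels, keeping the max accumulator and the sum accumulator in distinct variables (thread_data_max / thread_data_sum there, thread_max / thread_sum here) keeps the two reduction passes independent, and the __syncthreads() between them also satisfies cub::BlockReduce's requirement that a barrier precede reuse of its temporary storage.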
