diff --git a/onnxruntime/core/providers/cpu/math/softmax_shared.cc b/onnxruntime/core/providers/cpu/math/softmax_shared.cc
index 32df249f362a0..18f077d6c127d 100644
--- a/onnxruntime/core/providers/cpu/math/softmax_shared.cc
+++ b/onnxruntime/core/providers/cpu/math/softmax_shared.cc
@@ -36,8 +36,47 @@
 #include "gsl/gsl_algorithm"
 #include "gsl/gsl_util"
 
+#if defined(_OPENMP)
+#include <omp.h>
+#endif
+
 namespace onnxruntime {
 
+// Computes exp(X - rowwise_max(X)) for one contiguous group of `n` rows of
+// width `d`, writing the (unnormalized) result into Ydata. The final
+// normalization (Gemv + division) stays in SoftmaxCPU so groups can run in
+// parallel without sharing state; each group gets its own rowmax slice.
+common::Status SoftmaxCore(const int n,
+                           const int d,
+                           const float* Xdata,
+                           float* Ydata,
+                           const float* sum_multiplier,
+                           float* rowmax) {
+  const int nd = n * d;
+
+  math::RowwiseMax<float, CPUMathUtil>(n, d, Xdata, rowmax, nullptr);
+  // Put the intermediate result X - max(X) into Y by first copying X to Y, and then subtracting max from each entry
+  gsl::copy(gsl::make_span(Xdata, nd), gsl::make_span(Ydata, nd));
+  math::Gemm<float, CPUMathUtil>(CblasNoTrans, CblasNoTrans, n, d, 1, -1, rowmax, sum_multiplier, 1, Ydata, nullptr);
+  // Exponentiation
+  math::Exp<float, CPUMathUtil>(nd, Ydata, Ydata, nullptr);
+  return Status::OK();
+}
+
+// Number of row groups to process in parallel: bounded by the available
+// OpenMP threads, the row count, and a minimum amount of work per group.
+static int GetParallelGroupCount(int n, int d) {
+#if defined(_OPENMP)
+  // NOTE: must be omp_get_max_threads(); omp_get_num_threads() returns 1 when
+  // called outside a parallel region (as here), which would disable parallelism.
+  int omp_num_threads = omp_get_max_threads();
+  int group_count = std::min(omp_num_threads, n);
+  if (group_count <= 1) return 1;
+
+  // 2048 * sizeof(float) is the size of 2 memory pages; don't split work finer than that
+  static const int min_elements_per_group = 2048;
+  int max_groups = gsl::narrow_cast<int>((int64_t{n} * d + min_elements_per_group - 1) / min_elements_per_group);
+  return std::min(group_count, max_groups);
+#else
+  (void)n;
+  (void)d;
+  return 1;
+#endif
+}
+
 common::Status SoftmaxCPU(const int64_t N,
                           const int64_t D,
                           const float* Xdata,
@@ -57,21 +96,24 @@ common::Status SoftmaxCPU(const int64_t N,
   const int n = gsl::narrow_cast<int>(N);
   const int d = gsl::narrow_cast<int>(D);
-  const int nd = gsl::narrow_cast<int>(N * D);
 
-  math::RowwiseMax<float, CPUMathUtil>(n, d, Xdata, rowmax, nullptr);
-
-  // Put the intermediate result X - max(X) into Y by first copying X to Y, and then subtracting max from each entry
-  gsl::copy(gsl::make_span(Xdata, nd),
-            gsl::make_span(Ydata, nd));
+  int parallel_group_count = GetParallelGroupCount(n, d);
+  int n_per_group = (n + (parallel_group_count - 1)) / parallel_group_count;
 
-  math::Gemm<float, CPUMathUtil>(CblasNoTrans, CblasNoTrans, n, d, 1, -1, rowmax, sum_multiplier, 1, Ydata, nullptr);
+  #pragma omp parallel for
+  for (int i = 0; i < parallel_group_count; ++i) {
+    int s = n_per_group * i;
+    if (s < n) {
+      int c = (n - s >= n_per_group) ? n_per_group : (n - s);
+      SoftmaxCore(c, d, Xdata + (s * d), Ydata + (s * d),
+                  sum_multiplier, rowmax + s);
+    }
+  }
 
-  // Exponentiation
-  math::Exp<float, CPUMathUtil>(nd, Ydata, Ydata, nullptr);
   math::Gemv<float, CPUMathUtil>(CblasNoTrans, n, d, 1, Ydata, sum_multiplier, 0, scale, nullptr);
 
   // Do division
   if (!logarithmic) {
+    #pragma omp parallel for
     for (int i = 0; i < N; ++i) {
       for (int j = 0; j < D; ++j) {
         Ydata[i * D + j] /= scale[i];