Use at::Vectorized in optimized log_softmax
Pull Request resolved: #8382

This should allow us to enable this op in OSS, because Vectorized handles any Sleef issues for us as needed. (I considered going straight to sharing the PyTorch core implementation, but that requires parallel_for to be enabled, and this improvement is easy enough to make.)

Differential Revision: [D69473208](https://our.internmc.facebook.com/intern/diff/D69473208/)
ghstack-source-id: 267044107

Co-authored-by: Github Executorch <[email protected]>
pytorchbot and Github Executorch authored Feb 19, 2025
1 parent 8d1480b commit cc3974f
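
For readers unfamiliar with at::vec, the pattern the kernel switches to looks roughly like the sketch below: a Vectorized<float> main loop that loads a register's worth of elements, subtracts the running max, exponentiates, stores, and accumulates a lane sum, followed by a scalar tail loop. This is a minimal illustration only, not the actual op; the helper name exp_shifted_and_sum is made up here, and it assumes contiguous data, whereas the real kernel indexes with a stride.

#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>

#include <cmath>
#include <cstdint>
#include <functional>

// Sketch: write out[i] = exp(in[i] - max_val) and return the sum of the
// results. at::vec::Vectorized picks NEON/AVX/scalar code at build time,
// so no hand-written intrinsics or direct Sleef calls are needed.
float exp_shifted_and_sum(const float* in, float* out, int64_t n, float max_val) {
  using Vec = at::vec::Vectorized<float>;
  const Vec max_vec(max_val);  // broadcast max_val to every lane
  float sum = 0.0f;
  int64_t i = 0;
  // Vectorized main loop over full registers.
  for (; i + Vec::size() <= n; i += Vec::size()) {
    auto v = (Vec::loadu(in + i) - max_vec).exp();
    v.store(out + i);
    // Horizontal add of the lanes into the scalar accumulator.
    sum += at::vec::vec_reduce_all<float>(std::plus<Vec>(), v);
  }
  // Scalar tail for whatever does not fill a full register.
  for (; i < n; ++i) {
    out[i] = std::exp(in[i] - max_val);
    sum += out[i];
  }
  return sum;
}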
Showing 2 changed files with 21 additions and 24 deletions.
32 changes: 17 additions & 15 deletions kernels/optimized/cpu/op_log_softmax.cpp
@@ -14,6 +14,8 @@
 #include <cmath>
 #include <type_traits>
 
+#include <ATen/cpu/vec/functional.h>
+#include <ATen/cpu/vec/vec.h>
 #include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>

@@ -66,30 +68,30 @@ void log_softmax_kernel(const Tensor& input, int64_t dim, Tensor& out) {
       }
       // calculate sum and exponential in softmax dim
       OUT_T temp_sum = 0;
-#ifndef __aarch64__
-      for (auto d = 0; d < dim_size; ++d) {
-        output_data[d * dim_stride] =
-            std::exp(input_data[d * dim_stride] - max_input);
-        temp_sum += output_data[d * dim_stride];
-      }
-#else
+      using VecOut = at::vec::Vectorized<OUT_T>;
+      using VecIn = at::vec::Vectorized<IN_T>;
       auto d = 0;
-      for (; d + 4 < dim_size; d += 4) {
+      static_assert(sizeof(IN_T) == sizeof(OUT_T));
+      static_assert(
+          std::is_same_v<OUT_T, float>,
+          "Below loop actually only supports float.");
+      const VecIn max_input_vec(max_input);
+      for (; d + VecOut::size() < dim_size; d += VecOut::size()) {
         auto index = d * dim_stride;
-        float32x4_t in =
-            vld1q_f32(static_cast<const float*>(&input_data[index]));
-        float32x4_t out_ =
-            Sleef_expf4_u10(vsubq_f32(in, vmovq_n_f32(max_input)));
-        vst1q_f32(static_cast<float*>(&output_data[index]), out_);
+        auto in = VecIn::loadu(&input_data[index]);
+        auto out_ = (in - max_input_vec).exp();
+        out_.store(&output_data[index]);
+#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE)
         temp_sum += vaddvq_f32(out_);
+#else
+        temp_sum += at::vec::vec_reduce_all<float>(std::plus<VecOut>(), out_);
+#endif
       }
 
       for (; d < dim_size; ++d) {
         output_data[d * dim_stride] =
             std::exp(input_data[d * dim_stride] - max_input);
         temp_sum += output_data[d * dim_stride];
       }
-#endif // __aarch64__
 
       temp_sum = std::log(temp_sum);
 
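One detail worth calling out in the hunk above is the lane reduction: on aarch64 builds without SVE the code keeps the NEON vaddvq_f32 horizontal add (Vectorized<float> converts to a float32x4_t there), while every other build uses the generic at::vec::vec_reduce_all fallback. A small sketch of that split, with a made-up helper name sum_lanes, assuming the same headers as the kernel:

#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>

#include <functional>

// Horizontal sum of all lanes of a Vectorized<float>, mirroring the
// preprocessor split used inside the kernel's vectorized loop.
float sum_lanes(const at::vec::Vectorized<float>& v) {
#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE)
  // NEON path: rely on the implicit conversion to float32x4_t and let
  // vaddvq_f32 add the four lanes in one instruction.
  return vaddvq_f32(v);
#else
  // Generic path: fold the lanes together with Vectorized's operator+.
  return at::vec::vec_reduce_all<float>(
      std::plus<at::vec::Vectorized<float>>(), v);
#endif
}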
13 changes: 4 additions & 9 deletions kernels/optimized/cpu/targets.bzl
@@ -57,15 +57,10 @@ _OPTIMIZED_ATEN_OPS = (
     ),
     op_target(
         name = "op_log_softmax",
-        deps = select({
-            "DEFAULT": [
-                "//executorch/kernels/portable/cpu/util:activation_ops_util",
-            ],
-            "ovr_config//cpu:arm64": [
-                "//executorch/kernels/portable/cpu/util:activation_ops_util",
-                "fbsource//third-party/sleef:sleef_arm",
-            ],
-        }),
+        deps = [
+            "//executorch/kernels/portable/cpu/util:activation_ops_util",
+            "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
+        ],
     ),
     op_target(
         name = "op_mm",
