Optimize aarch64 GEMM kernel
Revise the aarch64 kernel to use NEON SIMD intrinsics. The structure is the
same as the AVX2 / FMA kernel, but the tile size is set to 8x8, as that
performed best.

On an M1 Mac, performance for an M=N=K=1024 matmul increases from ~334 to ~418
GFLOPS.
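(For context: an M=N=K=1024 matmul performs 2 * 1024^3 ≈ 2.15 GFLOP, so these
figures correspond to roughly 6.4 ms and 5.1 ms per multiply, respectively.)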
robertknight committed Jan 5, 2024 · 1 parent 3701560 · commit bc57ee8
Showing 1 changed file with 36 additions and 27 deletions.
63 changes: 36 additions & 27 deletions src/gemm/kernels/aarch64.rs
@@ -2,23 +2,12 @@ use super::Kernel;

use crate::iter_util::unroll_loop;

/// This is not a fully optimized ARM NEON kernel, just an initial version
/// which is a copy of the base kernel that has been tweaked to:
///
/// - Use a larger tile size
/// - Use FMA instructions via `f32::mul_add`
/// - Unroll the inner loop
#[derive(Default)]
pub struct ArmNeonKernel {}

impl Kernel for ArmNeonKernel {
// ARM NEON has 32 registers. Empirically 14x4 is the largest tile size
// this naive auto-vectorized implementation can use before LLVM spills
// registers and performance drops. Better kernels in eg. OpenBLAS have
// 64-element tiles (8x8 or 16x4).

const MR: usize = 14;
const NR: usize = 4;
const MR: usize = 8;
const NR: usize = 8;
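// With MR = NR = 8 the accumulator tile occupies MR * NR / 4 = 16 of the 32
// 128-bit NEON vector registers, leaving room for the loaded B row vectors
// and the broadcast A value.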

fn name() -> &'static str {
"arm-neon"
@@ -37,48 +26,68 @@ impl Kernel for ArmNeonKernel {
alpha: f32,
beta: f32,
) {
use std::arch::aarch64::{
vaddq_f32, vdupq_n_f32, vfmaq_f32, vld1q_f32, vmulq_f32, vst1q_f32,
};
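// NEON intrinsics used below: `vdupq_n_f32` broadcasts a scalar to all four
// lanes of a 128-bit register, `vld1q_f32`/`vst1q_f32` load/store four
// contiguous f32 values, `vfmaq_f32(acc, a, b)` computes `acc + a * b` per
// lane, and `vaddq_f32`/`vmulq_f32` are lane-wise add/multiply.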

const MR: usize = ArmNeonKernel::MR;
const NR: usize = ArmNeonKernel::NR;
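// Each 128-bit NEON register holds four f32 lanes, so one NR-wide row of the
// output tile spans NR_REGS vector registers.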
const REG_SIZE: usize = 4;
const NR_REGS: usize = NR / REG_SIZE;

assert!(a.len() >= depth * MR);
assert!(b.len() >= depth * NR);
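// The packed A panel holds `depth` groups of MR values and the packed B panel
// holds `depth` groups of NR values, as the `a_off`/`b_off` strides below assume.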

let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();

// Accumulate into a fixed-sized array to allow the compiler to generate
// more efficient code for the loop over `depth`.
let mut tmp = [[0.0; NR]; MR];
let mut tmp = [[vdupq_n_f32(0.); NR_REGS]; MR];
let mut b_rows = [vdupq_n_f32(0.); NR_REGS];
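// For each step along `depth`: load one row of the packed B panel into
// `b_rows`, then broadcast each packed A value and fuse-multiply-accumulate it
// into the `tmp` tile. The loop is unrolled by a factor of 8 via `unroll_loop!`.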

unroll_loop!(depth, k, 8, {
let a_off = k * MR;
let b_off = k * NR;

for i in 0..NR_REGS {
b_rows[i] = vld1q_f32(b_ptr.add(b_off + i * REG_SIZE));
}

for i in 0..MR {
for j in 0..NR {
tmp[i][j] = a
.get_unchecked(a_off + i)
.mul_add(*b.get_unchecked(b_off + j), tmp[i][j]);
let a_val = *a_ptr.add(a_off + i);
let a_broadcast = vdupq_n_f32(a_val);

for j in 0..NR_REGS {
tmp[i][j] = vfmaq_f32(tmp[i][j], a_broadcast, b_rows[j]);
}
}
});

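// Epilogue: write the accumulated tile back to C, specializing the common
// cases C = A*B (alpha == 1, beta == 0) and C += A*B (alpha == 1, beta == 1),
// and falling back to C = alpha * A*B + beta * C otherwise.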
if beta == 0. && alpha == 1. {
for i in 0..MR {
for j in 0..NR {
let out_el = tile_ptr.add(tile_row_stride * i + j);
*out_el = tmp[i][j];
for j in 0..NR_REGS {
let out_ptr = tile_ptr.add(tile_row_stride * i + j * REG_SIZE);
vst1q_f32(out_ptr, tmp[i][j]);
}
}
} else if beta == 1. && alpha == 1. {
for i in 0..MR {
for j in 0..NR {
let out_el = tile_ptr.add(tile_row_stride * i + j);
*out_el += tmp[i][j];
for j in 0..NR_REGS {
let out_ptr = tile_ptr.add(tile_row_stride * i + j * REG_SIZE);
let out_val = vaddq_f32(vld1q_f32(out_ptr), tmp[i][j]);
vst1q_f32(out_ptr, out_val);
}
}
} else {
let alpha_broadcast = vdupq_n_f32(alpha);
let beta_broadcast = vdupq_n_f32(beta);
for i in 0..MR {
for j in 0..NR {
let out_el = tile_ptr.add(tile_row_stride * i + j);
*out_el = beta * *out_el + alpha * tmp[i][j];
for j in 0..NR_REGS {
let out_ptr = tile_ptr.add(tile_row_stride * i + j * REG_SIZE);
let out_val = vmulq_f32(vld1q_f32(out_ptr), beta_broadcast);
let out_val = vfmaq_f32(out_val, tmp[i][j], alpha_broadcast);
vst1q_f32(out_ptr, out_val);
}
}
}
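For reference, here is a minimal scalar sketch of the contract this micro-kernel implements (illustrative only; the function name and const-generic signature below are not part of this repository): accumulate an MR x NR tile from the packed A and B panels, then apply alpha and beta when writing back to C.

fn reference_kernel<const MR: usize, const NR: usize>(
    tile: &mut [f32],   // MR rows of the output tile, each `row_stride` apart
    row_stride: usize,
    a: &[f32],          // packed A panel, length depth * MR
    b: &[f32],          // packed B panel, length depth * NR
    depth: usize,
    alpha: f32,
    beta: f32,
) {
    // Accumulate the MR x NR product of the packed panels in a local tile.
    let mut tmp = [[0.0f32; NR]; MR];
    for k in 0..depth {
        for i in 0..MR {
            for j in 0..NR {
                tmp[i][j] += a[k * MR + i] * b[k * NR + j];
            }
        }
    }
    // Write back: C = alpha * (A * B) + beta * C.
    for i in 0..MR {
        for j in 0..NR {
            let out = &mut tile[i * row_stride + j];
            *out = beta * *out + alpha * tmp[i][j];
        }
    }
}

The NEON kernel computes the same result, but keeps `tmp` in vector registers and specializes the common alpha/beta combinations.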
