Skip to content

Commit

Permalink
Merge pull request #518 from robertknight/avx512-gemm-tweaks
Browse files Browse the repository at this point in the history
Use a more aggressive approach to unrolling in `simd_gemm`
  • Loading branch information
robertknight authored Jan 5, 2025
2 parents d85c66b + 12a4907 commit 2cd82d3
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/gemm/kernels/simd_generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use rten_simd::SimdFloat;
use rten_tensor::{Matrix, MatrixLayout, Storage};

use super::Lhs;
use crate::iter_util::{range_chunks_exact, unroll_loop};
use crate::iter_util::{range_chunks_exact, unroll_loop, unroll_loop_x4};

/// Compute an output block of a vector-matrix product ("gemv" in BLAS APIs).
///
Expand Down Expand Up @@ -293,7 +293,7 @@ pub unsafe fn simd_gemm<S: SimdFloat, const MR: usize, const NR_REGS: usize, con
let mut tmp = [[S::zero(); NR_REGS]; ROWS];
let mut b_rows = [S::zero(); NR_REGS];

unroll_loop!(0..depth - 1, k, 4, {
unroll_loop_x4!(0..depth - 1, k, {
let b_off = k * NR_REGS * S::LEN;

// Prefetch B for the next iteration
Expand Down
39 changes: 38 additions & 1 deletion src/iter_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,49 @@ impl MaybeParIter for Range<usize> {
}
}

/// Unroll a loop 4x.
///
/// This is very similar to [`unroll_loop`] but uses a more aggressive approach
/// to unrolling which only supports a fixed unroll factor. Whereas
/// `unroll_loop` uses a hint (a `for` loop with a fixed iteration count) which
/// the compiler follows most of the time, this macro actually duplicates the
/// body 4x.
macro_rules! unroll_loop_x4 {
($range:expr, $loop_var:ident, $block:tt) => {
let mut n = $range.len();
let mut $loop_var = $range.start;

while n >= 4 {
$block;
$loop_var += 1;
$block;
$loop_var += 1;
$block;
$loop_var += 1;
$block;
$loop_var += 1;
n -= 4;
}

while n > 0 {
$block;
$loop_var += 1;
n -= 1;
}
};
}

/// Generate an unrolled loop.
///
/// `$range` is a `Range` specifying the loop start and end. `$loop_var` is the
/// name of the variable containing the current iteration inside `$block`.
/// `$factor` should be a constant expression specifying the unroll factor,
/// typically a small value such as 4 or 8.
///
/// This macro generates a "hint" in the form of a `for` loop with a const
/// iteration count which the compiler follows in most cases. If it doesn't,
/// and you're sure you still need unrolling, consider [`unroll_loop_x4`]
/// instead.
macro_rules! unroll_loop {
($range:expr, $loop_var:ident, $factor: expr, $block:tt) => {
let mut n = $range.len();
Expand All @@ -161,7 +198,7 @@ macro_rules! unroll_loop {
}

#[allow(unused_imports)]
pub(crate) use unroll_loop;
pub(crate) use {unroll_loop, unroll_loop_x4};

#[cfg(test)]
mod tests {
Expand Down

0 comments on commit 2cd82d3

Please sign in to comment.