diff --git a/src/gemm/kernels/simd_generic.rs b/src/gemm/kernels/simd_generic.rs index fe18b2f8..30f3b863 100644 --- a/src/gemm/kernels/simd_generic.rs +++ b/src/gemm/kernels/simd_generic.rs @@ -4,7 +4,7 @@ use rten_simd::SimdFloat; use rten_tensor::{Matrix, MatrixLayout, Storage}; use super::Lhs; -use crate::iter_util::{range_chunks_exact, unroll_loop}; +use crate::iter_util::{range_chunks_exact, unroll_loop, unroll_loop_x4}; /// Compute an output block of a vector-matrix product ("gemv" in BLAS APIs). /// @@ -293,7 +293,7 @@ pub unsafe fn simd_gemm { } } +/// Unroll a loop 4x. +/// +/// This is very similar to [`unroll_loop`] but uses a more aggressive approach +/// to unrolling which only supports a fixed unroll factor. Whereas +/// `unroll_loop` uses a hint (a `for` loop with a fixed iteration count) which +/// the compiler follows most of the time, this macro actually duplicates the +/// body 4x. +macro_rules! unroll_loop_x4 { + ($range:expr, $loop_var:ident, $block:tt) => { + let mut n = $range.len(); + let mut $loop_var = $range.start; + + while n >= 4 { + $block; + $loop_var += 1; + $block; + $loop_var += 1; + $block; + $loop_var += 1; + $block; + $loop_var += 1; + n -= 4; + } + + while n > 0 { + $block; + $loop_var += 1; + n -= 1; + } + }; +} + /// Generate an unrolled loop. /// /// `$range` is a `Range` specifying the loop start and end. `$loop_var` is the /// name of the variable containing the current iteration inside `$block`. /// `$factor` should be a constant expression specifying the unroll factor, /// typically a small value such as 4 or 8. +/// +/// This macro generates a "hint" in the form of a `for` loop with a const +/// iteration count which the compiler follows in most cases. If it doesn't, +/// and you're sure you still need unrolling, consider [`unroll_loop_x4`] +/// instead. macro_rules! unroll_loop { ($range:expr, $loop_var:ident, $factor: expr, $block:tt) => { let mut n = $range.len(); @@ -161,7 +198,7 @@ macro_rules! unroll_loop { } #[allow(unused_imports)] -pub(crate) use unroll_loop; +pub(crate) use {unroll_loop, unroll_loop_x4}; #[cfg(test)] mod tests {