Skip to content

Commit

Permalink
Optimize int8 GEMV with non-transposed B matrix
Browse files Browse the repository at this point in the history
Instead of loading 4 rows of 4 elements from B at a time, load 4 rows of 16
elements and interleave them to give four transposed 4x4 tiles.

The inner loop over K was also changed to manually increment `k` instead of
using `range_chunks_exact` as this generated slightly better code.

On an M3 Pro this improved the `bench_gemm_mix` gemv benchmark for int8
from ~45 GFLOPS to ~60 GFLOPS for the non-transposed case.
  • Loading branch information
robertknight committed Jan 29, 2025
1 parent c6da16e commit 523330a
Showing 1 changed file with 80 additions and 49 deletions.
129 changes: 80 additions & 49 deletions src/gemm/kernels/simd_generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -578,71 +578,102 @@ pub unsafe fn simd_int8_gemv<S: SimdInt, const CAST_B_U8: bool>(

let row_sum: i32 = a.iter().map(|x| *x as i32).sum();

let mut col_tiles = range_chunks_exact(0..b.cols(), S::LEN);
// Iterate over `4 x S::LEN` columns at a time, which is one `i8` SIMD vec,
// four `i32` vecs.
let mut col_tiles = range_chunks_exact(0..b.cols(), 4 * S::LEN);
for col_tile in col_tiles.by_ref() {
let b_ptr = b_ptr.add(col_tile.start);
let mut acc = S::zero();
let mut col_sums = S::zero();
let mut acc = [S::zero(); 4];
let mut col_sums = [S::zero(); 4];
let one_u8 = S::splat(i32::from_le_bytes([1; 4]));

// Loop over K tiles of size 4.
let mut k_tiles = range_chunks_exact(0..depth, 4);
for k_tile in k_tiles.by_ref() {
// Broadcast 4x u8 values
let a = S::splat(*(a_ptr.add(k_tile.start) as *const i32));

// Load `S::LEN` groups of 4 i8 values.
let b = S::load_interleave_i8(
b_ptr.add(k_tile.start * b_row_stride),
b_ptr.add((k_tile.start + 1) * b_row_stride),
b_ptr.add((k_tile.start + 2) * b_row_stride),
b_ptr.add((k_tile.start + 3) * b_row_stride),
);
let b = if CAST_B_U8 { b.xor(bit_flip_mask) } else { b };

// Compute `C += dot(A, B)` for each of the `S::LEN` columns.
acc = dot_product(a, b, acc);
col_sums = dot_product(one_u8, b, col_sums);
let mut k = 0;
while k + 4 <= depth {
// Broadcast 4 values from A.
let a = S::splat(*(a_ptr.add(k) as *const i32));

// Load 4 rows of 16 int8 elements from B and interleave to give
// 4 transposed `[4, S::LEN]` tiles. eg. Given 4 rows A, B, C, D,
// if `S::LEN` = 4, the tiles are:
//
// tile 0: A0 B0 C0 D0 ... A3 B3 C3 D3
// tile 1: A4 B4 C4 D4 ... A7 B7 C7 D7
// tile 2: A8 B8 C8 D8 ... A11 B11 C11 D11
// tile 3: A12 B12 C12 D12 ... A15 B15 C15 D15
let b0 = S::load(b_ptr.add(k * b_row_stride) as *const i32);
let b1 = S::load(b_ptr.add((k + 1) * b_row_stride) as *const i32);
let b2 = S::load(b_ptr.add((k + 2) * b_row_stride) as *const i32);
let b3 = S::load(b_ptr.add((k + 3) * b_row_stride) as *const i32);

let b01_lo = b0.zip_lo_i8(b1);
let b01_hi = b0.zip_hi_i8(b1);
let b23_lo = b2.zip_lo_i8(b3);
let b23_hi = b2.zip_hi_i8(b3);

let b_tiles = [
b01_lo.zip_lo_i16(b23_lo),
b01_lo.zip_hi_i16(b23_lo),
b01_hi.zip_lo_i16(b23_hi),
b01_hi.zip_hi_i16(b23_hi),
];

for i in 0..4 {
let b_tile = if CAST_B_U8 {
b_tiles[i].xor(bit_flip_mask)
} else {
b_tiles[i]
};
acc[i] = dot_product(a, b_tile, acc[i]);
col_sums[i] = dot_product(one_u8, b_tile, col_sums[i]);
}
k += 4;
}

for k in k_tiles.remainder() {
let a = S::splat(*a_ptr.add(k) as i32);
let b = S::load_extend_i8(b_ptr.add(k * b_row_stride));
let b = if CAST_B_U8 {
b.add(S::splat(b_zero_shift))
} else {
b
};
while k < depth {
let a = S::splat((*a_ptr.add(k)).into());

for i in 0..4 {
let b = S::load_extend_i8(b_ptr.add(k * b_row_stride + i * S::LEN));
let b = if CAST_B_U8 {
b.add(S::splat(b_zero_shift))
} else {
b
};

acc = a.mul(b).add(acc);
col_sums = col_sums.add(b);
acc[i] = a.mul(b).add(acc[i]);
col_sums[i] = col_sums[i].add(b);
}
k += 1;
}

// Subtract zero points. This is equivalent to doing
// `acc += (a - a_zero) * (b - b_zero)` in the loop over K, but more
// efficient.
let row_sum_vec = S::splat(row_sum);
let depth_vec = S::splat(depth as i32);
let a_zero_vec = S::splat(a_zero_point as i32);
let b_zero_vec = if let Some(b_zero) = b_zero_points {
S::load_extend_i8(b_zero.as_ptr().add(col_tile.start))
} else {
S::zero()
};
let b_zero_vec = b_zero_vec.add(S::splat(b_zero_shift));
let a_zero_vec = S::splat(a_zero_point.into());

acc = depth_vec
.mul(a_zero_vec)
.mul(b_zero_vec)
.add(acc)
.sub(row_sum_vec.mul(b_zero_vec))
.sub(col_sums.mul(a_zero_vec));

let out_ptr = out.as_ptr().add(col_tile.start) as *mut i32;
if !accumulate {
acc.store(out_ptr);
} else {
S::load(out_ptr).add(acc).store(out_ptr);
for i in 0..4 {
let b_zero_vec = if let Some(b_zero) = b_zero_points {
S::load_extend_i8(b_zero.as_ptr().add(col_tile.start + i * S::LEN))
} else {
S::zero()
};
let b_zero_vec = b_zero_vec.add(S::splat(b_zero_shift));
acc[i] = depth_vec
.mul(a_zero_vec)
.mul(b_zero_vec)
.add(acc[i])
.sub(row_sum_vec.mul(b_zero_vec))
.sub(col_sums[i].mul(a_zero_vec));

let out_ptr = out.as_ptr().add(col_tile.start + i * S::LEN) as *mut i32;
if !accumulate {
acc[i].store(out_ptr);
} else {
S::load(out_ptr).add(acc[i]).store(out_ptr);
}
}
}

Expand Down

0 comments on commit 523330a

Please sign in to comment.