Skip to content

Commit

Permalink
Merge pull request #564 from robertknight/arm-vector-vector-test
Browse files Browse the repository at this point in the history
Fix I8 -> U8 conversion in int8 gemv transposed case
  • Loading branch information
robertknight authored Jan 31, 2025
2 parents f4b954a + 1ec0353 commit 4424cb8
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 7 deletions.
30 changes: 24 additions & 6 deletions src/gemm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1852,6 +1852,9 @@ mod tests {
Case { m: 1, n: 5, k: 10 },
Case { m: 1, n: 8, k: 4 },
Case { m: 1, n: 16, k: 4 },
// Vector-matrix product, K not a multiple of 4 (tile size used by
// int8 dot product instructions).
Case { m: 1, n: 1, k: 2 },
// Vector-matrix, where n is large enough that work should be
// divided into multiple blocks.
Case {
Expand Down Expand Up @@ -1945,17 +1948,32 @@ mod tests {

#[test]
fn test_gemv_u8i8_i32_transposed() -> Result<(), Box<dyn Error>> {
struct Case {
n: usize,
k: usize,
}

let cases = [
// K multiple of 4
Case { k: 8, n: 5 },
// K not a multiple of 4
Case { k: 2, n: 5 },
];

for gemm in all_gemms::<u8, i8, i32>() {
let mut lhs_rng = XorShiftRng::new(1234);
let mut rhs_rng = ReducedRangeRng::new(gemm.may_saturate());
let a = NdTensor::<u8, 2>::rand([1, 8], &mut lhs_rng);
let mut b = NdTensor::<i8, 2>::from_simple_fn([5, 8], || rhs_rng.next_i8());

// Transpose the input B matrix. This will alter the row and column
// strides and shapes, but not re-order the data.
b.permute([1, 0]);
for &Case { k, n } in &cases {
let a = NdTensor::<u8, 2>::rand([1, k], &mut lhs_rng);
let mut b = NdTensor::<i8, 2>::from_simple_fn([n, k], || rhs_rng.next_i8());

run_compare_matmul(a.view(), b.view(), None, Some(&gemm));
// Transpose the input B matrix. This will alter the row and column
// strides and shapes, but not re-order the data.
b.permute([1, 0]);

run_compare_matmul(a.view(), b.view(), None, Some(&gemm));
}
}

Ok(())
Expand Down
3 changes: 2 additions & 1 deletion src/gemm/kernels/simd_generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -759,14 +759,15 @@ unsafe fn simd_int8_gemv_transposed<S: SimdInt, const CAST_B_U8: bool>(
for k in k_tiles.remainder() {
let a = *a_ptr.add(k) as i32;
let b = *b_ptr.add(k) as i32;
let b = if CAST_B_U8 { b + b_zero_shift } else { b };
acc += a * b;
col_sum += b;
}

let a_zero = a_zero_point as i32;
let b_zero = b_zero_points
.map(|bz| bz[col] as i32 + b_zero_shift)
.unwrap_or(0);
.unwrap_or(b_zero_shift);
let acc = (depth as i32 * a_zero * b_zero) + acc - row_sum * b_zero - col_sum * a_zero;

let out_ptr = out.get_unchecked_mut(col);
Expand Down

0 comments on commit 4424cb8

Please sign in to comment.