Skip to content

Commit

Permalink
Merge pull request #572 from robertknight/broadcast-matmul-zero-point
Browse files Browse the repository at this point in the history
Broadcast zero point vector when converting batched matmul to non-batched
  • Loading branch information
robertknight authored Feb 3, 2025
2 parents f5fcba9 + 81e2328 commit b5a6d1b
Showing 1 changed file with 24 additions and 2 deletions.
26 changes: 24 additions & 2 deletions src/ops/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,18 @@ where
// nb. We assume `a` is likely already contiguous, so this will be cheap.
let a_contig = a.to_contiguous_in(pool).auto_return(pool);
let a_matrix = a_contig.reshaped([num_a_matrices * a_rows, a_cols].as_slice());

// Broadcast zero point to match new row count.
let a_quant: Option<Vec<LhsT>> = a_quant.map(|a_quant| {
a_quant
.zero_point
.iter()
.copied()
.cycle()
.take(a_matrix.size(0))
.collect()
});

let mut output = matmul_impl(
pool,
a_matrix.view(),
Expand All @@ -230,7 +242,9 @@ where
strategy,
bias,
alpha,
a_quant,
a_quant.as_ref().map(|zero_point| QuantParams {
zero_point: zero_point.as_slice(),
}),
b_quant,
)?;
output.reshape(out_shape);
Expand Down Expand Up @@ -1065,7 +1079,15 @@ mod tests {
b_zero_point: Some(Tensor::from([3, 4])),
expected_err: None,
},
// A input which is a row vector
// LHS batch input with vector zero point
Case {
a: Tensor::zeros(&[3, 2, 2]),
b: Tensor::from([[5, 6], [7, 8]]),
a_zero_point: Some(Tensor::from([1, 2])),
b_zero_point: Some(Tensor::from([3, 4])),
expected_err: None,
},
// An input which is a row vector
Case {
a: Tensor::from([[1, 2, 3, 4]]),
b: Tensor::from([[5, 6], [7, 8], [9, 10], [11, 12]]),
Expand Down

0 comments on commit b5a6d1b

Please sign in to comment.