Skip to content

Commit

Permalink
Merge pull request #418 from robertknight/matmul-vec
Browse files Browse the repository at this point in the history
Support vector inputs in MatMul operator
  • Loading branch information
robertknight authored Nov 29, 2024
2 parents def4197 + 990ce22 commit 2bfd667
Showing 1 changed file with 79 additions and 10 deletions.
89 changes: 79 additions & 10 deletions src/ops/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,15 +132,26 @@ where

fn matmul_impl<LhsT: GemmInT, RhsT: GemmInT, OutT: Default + GemmOutT>(
pool: &TensorPool,
a: TensorView<LhsT>,
b: TensorView<RhsT>,
mut a: TensorView<LhsT>,
mut b: TensorView<RhsT>,
strategy: MatmulStrategy,
) -> Result<Tensor<OutT>, OpError>
where
GemmExecutor<LhsT, RhsT, OutT>: Default,
{
if a.ndim() < 2 || b.ndim() < 2 {
return Err(OpError::InvalidValue("Inputs must have >= 2 dimensions"));
if a.ndim() < 1 || b.ndim() < 1 {
return Err(OpError::InvalidValue("Inputs must have >= 1 dimensions"));
}

// Expand vector inputs to matrices. This follows the rules of `numpy.matmul`.
// See https://numpy.org/doc/stable/reference/generated/numpy.matmul.html.
let a_is_vec = a.ndim() == 1;
if a_is_vec {
a.insert_axis(0);
}
let b_is_vec = b.ndim() == 1;
if b_is_vec {
b.insert_axis(1);
}

let a_rows = a.size(a.ndim() - 2);
Expand Down Expand Up @@ -246,7 +257,14 @@ where
});

// Safety: Loop above initialized all output elements.
let output = unsafe { output.assume_init() };
let mut output = unsafe { output.assume_init() };

if a_is_vec {
output.remove_axis(output.ndim() - 2);
}
if b_is_vec {
output.remove_axis(output.ndim() - 1);
}

Ok(output)
}
Expand Down Expand Up @@ -419,7 +437,30 @@ mod tests {
/// Multiply matrices in `a` by corresponding matrices in `b` and write to
/// `c`. The shapes of `a` and `b` are broadcast so that their first N-2
/// dims match `c`.
fn reference_matmul(mut c: TensorViewMut, a: TensorView, b: TensorView) {
fn reference_matmul(mut c: TensorViewMut, mut a: TensorView, mut b: TensorView) {
// Expand vector inputs to matrices. This follows the rules of
// `numpy.matmul`.
let a_is_vec = a.ndim() == 1;
if a_is_vec {
a.insert_axis(0);
}
let b_is_vec = b.ndim() == 1;
if b_is_vec {
b.insert_axis(1);
}

// If one or both of the inputs are vectors, temporarily expand the
// output shape to match the expanded input shapes.
match (a_is_vec, b_is_vec) {
(true, false) => c.insert_axis(c.ndim() - 1),
(false, true) => c.insert_axis(c.ndim()),
(true, true) => {
c.insert_axis(c.ndim());
c.insert_axis(c.ndim());
}
(false, false) => {}
}

let a_batch_dims = a.ndim() - 2;
let b_batch_dims = b.ndim() - 2;
let out_prefix = &c.shape()[..c.ndim() - 2];
Expand All @@ -442,6 +483,16 @@ mod tests {
0., /* beta */
)
});

match (a_is_vec, b_is_vec) {
(true, false) => c.remove_axis(c.ndim() - 2),
(false, true) => c.remove_axis(c.ndim() - 1),
(true, true) => {
c.remove_axis(c.ndim() - 1);
c.remove_axis(c.ndim() - 1);
}
(false, false) => {}
}
}

fn reference_matmul_integer(
Expand Down Expand Up @@ -639,6 +690,24 @@ mod tests {
b_shape: &[2, 10, 8],
out_shape: &[2, 3, 8],
},
// LHS is a vector
Case {
a_shape: &[4],
b_shape: &[4, 8],
out_shape: &[8],
},
// RHS is a vector
Case {
a_shape: &[4, 6],
b_shape: &[6],
out_shape: &[4],
},
// LHS and RHS are both vectors
Case {
a_shape: &[4],
b_shape: &[4],
out_shape: &[],
},
];

let pool = new_pool();
Expand Down Expand Up @@ -672,14 +741,14 @@ mod tests {

let cases = [
Case {
a_shape: &[3],
a_shape: &[],
b_shape: &[10, 8],
error: OpError::InvalidValue("Inputs must have >= 2 dimensions"),
error: OpError::InvalidValue("Inputs must have >= 1 dimensions"),
},
Case {
a_shape: &[3, 10],
b_shape: &[10],
error: OpError::InvalidValue("Inputs must have >= 2 dimensions"),
b_shape: &[],
error: OpError::InvalidValue("Inputs must have >= 1 dimensions"),
},
Case {
a_shape: &[3, 10],
Expand Down

0 comments on commit 2bfd667

Please sign in to comment.