Add non-optimised MatMulInteger implementation #356

Merged 2 commits on Sep 12, 2024
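This adds a non-optimised implementation of the ONNX MatMulInteger operator: a batched matrix multiply over u8 × i8 inputs that subtracts optional scalar (per-tensor) zero points and accumulates in i32. The operator is registered in the FlatBuffers schema, the op registry, the model builder and the generated Python schema used by rten-convert, and the kernel itself is a simple blocked loop rather than a tuned GEMM. As a rough sketch of the per-element semantics that the kernel and the test reference implement (the helper below is illustrative only and not part of this diff):

```rust
// Illustrative only (not part of this PR): the 2-D core of MatMulInteger.
// `a` is u8, `b` is i8, zero points default to 0, accumulation is in i32.
fn matmul_integer_2d(a: &[Vec<u8>], b: &[Vec<i8>], a_zero: i32, b_zero: i32) -> Vec<Vec<i32>> {
    let (m, k, n) = (a.len(), a[0].len(), b[0].len());
    let mut out = vec![vec![0i32; n]; m];
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0i32;
            for d in 0..k {
                // Subtract the zero points, then multiply-accumulate in i32.
                acc += (a[i][d] as i32 - a_zero) * (b[d][j] as i32 - b_zero);
            }
            out[i][j] = acc;
        }
    }
    out
}
```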
1 change: 1 addition & 0 deletions rten-convert/rten_convert/schema_generated.py
@@ -115,6 +115,7 @@ class OperatorType(object):
DequantizeLinear = 105
QuantizeLinear = 106
DynamicQuantizeLinear = 107
MatMulInteger = 108


class RNNDirection(object):
2 changes: 2 additions & 0 deletions src/model_builder.rs
@@ -77,6 +77,7 @@ pub enum OpType<'a> {
Log,
LogSoftmax(LogSoftmax),
MatMul,
MatMulInteger,
Max,
MaxPool(MaxPool),
Mean,
@@ -614,6 +615,7 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> {
}
),
OpType::MatMul => op!(MatMul),
OpType::MatMulInteger => op!(MatMulInteger),
OpType::Max => op!(Max),
OpType::MaxPool(args) => op_with_attrs!(MaxPool, MaxPoolAttrs, {
let pad_args = pad_args_from_padding(args.padding);
2 changes: 2 additions & 0 deletions src/op_registry.rs
@@ -133,6 +133,7 @@ impl OpRegistry {
register_op!(LogSoftmax);
register_op!(LSTM);
register_op!(MatMul);
register_op!(MatMulInteger);
register_op!(Max);
register_op!(MaxPool);
register_op!(Mean);
@@ -610,6 +611,7 @@ impl_read_op!(LSTM, attrs_as_lstmattrs, |attrs: sg::LSTMAttrs| {
})
});
impl_read_op!(MatMul);
impl_read_op!(MatMulInteger);
impl_read_op!(Max);
impl_read_op!(
MaxPool,
295 changes: 294 additions & 1 deletion src/ops/matmul.rs
@@ -5,6 +5,7 @@ use rten_tensor::{Tensor, TensorView};

use crate::check_dims;
use crate::gemm::{GemmExecutor, GemmInT, GemmInputA, GemmInputB, GemmOutT};
use crate::iter_util::range_chunks;
use crate::ops::binary_elementwise::broadcast_shapes;
use crate::ops::layout::expand_to;
use crate::ops::{InputList, IntoOpResult, OpError, Operator, OutputList};
@@ -266,6 +267,126 @@ impl Operator for MatMul {
}
}

pub fn matmul_integer(
pool: &TensorPool,
a: TensorView<u8>,
b: TensorView<i8>,
a_zero_point: Option<TensorView<u8>>,
b_zero_point: Option<TensorView<i8>>,
) -> Result<Tensor<i32>, OpError> {
if a.ndim() < 2 || b.ndim() < 2 {
return Err(OpError::InvalidValue("Inputs must have >= 2 dimensions"));
}

let a_rows = a.size(a.ndim() - 2);
let a_cols = a.size(a.ndim() - 1);

let b_rows = b.size(b.ndim() - 2);
let b_cols = b.size(b.ndim() - 1);

if a_cols != b_rows {
return Err(OpError::IncompatibleInputShapes(
"Columns of first matrix does not match rows of second matrix",
));
}

let a_prefix = &a.shape()[..a.ndim() - 2];
let b_prefix = &b.shape()[..b.ndim() - 2];

let out_prefix = broadcast_shapes(a_prefix, b_prefix)
.ok_or(OpError::IncompatibleInputShapes("Cannot broadcast shapes"))?;
let out_shape = &[out_prefix.as_slice(), &[a_rows, b_cols]].concat();

let mut output = Tensor::<i32>::uninit_in(pool, out_shape);
if output.is_empty() {
// nb. We don't need to alloc from the pool here, since the buffer
// is already empty.
return Ok(Tensor::zeros(out_shape));
}

let a_broadcast_shape = [out_prefix.as_slice(), &[a_rows, a_cols]].concat();
let b_broadcast_shape = [out_prefix.as_slice(), &[b_rows, b_cols]].concat();

let a_broadcast = a.broadcast(a_broadcast_shape.as_slice());
let b_broadcast = b.broadcast(b_broadcast_shape.as_slice());

fn is_scalar<T>(tensor: &Option<TensorView<T>>) -> bool {
tensor.as_ref().map(|zp| zp.ndim() == 0).unwrap_or(true)
}

if !is_scalar(&a_zero_point) || !is_scalar(&b_zero_point) {
return Err(OpError::UnsupportedValue(
"Only scalar zero points are supported",
));
}

let a_zero = a_zero_point.and_then(|zp| zp.item()).copied().unwrap_or(0) as i32;
let b_zero = b_zero_point.and_then(|zp| zp.item()).copied().unwrap_or(0) as i32;

a_broadcast
.inner_iter::<2>()
.zip(b_broadcast.inner_iter::<2>())
.zip(output.inner_iter_mut::<2>())
.par_bridge()
.for_each(|((a_mat, b_mat), mut out_mat)| {
let [m, k] = a_mat.shape();
let [bk, n] = b_mat.shape();
assert_eq!(k, bk);
assert_eq!(out_mat.shape(), [m, n]);

// Do some extremely rudimentary cache blocking.
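// Iterate over 32-element tiles of the N (column), K (depth) and M (row)
// dimensions so that the slices of `a_mat` and `b_mat` touched by the
// innermost loops stay small enough to remain cache-resident.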
for col_block in range_chunks(0..n, 32) {
for depth_block in range_chunks(0..k, 32) {
for row_block in range_chunks(0..m, 32) {
for j in col_block.clone() {
for i in row_block.clone() {
let mut out = 0i32;
for k in depth_block.clone() {
// Safety: `[i, k]` is in-bounds for `a_mat`.
let a = unsafe { *a_mat.get_unchecked([i, k]) } as i32 - a_zero;
// Safety: `[k, j]` is in-bounds for `b_mat`.
let b = unsafe { *b_mat.get_unchecked([k, j]) } as i32 - b_zero;
out += a * b;
}
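// The first depth block initialises the output element with its partial
// sum; subsequent depth blocks accumulate onto the previously written value.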
unsafe {
// Safety: `[i, j]` is in-bounds for `out_mat`.
let el = out_mat.get_unchecked_mut([i, j]);
if depth_block.start == 0 {
el.write(out);
} else {
el.write(el.assume_init() + out);
}
}
}
}
}
}
}
});

// Safety: Loop above initialized all output elements.
let output = unsafe { output.assume_init() };

Ok(output)
}

#[derive(Debug)]
pub struct MatMulInteger {}

impl Operator for MatMulInteger {
fn name(&self) -> &str {
"MatMulInteger"
}

fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<OutputList, OpError> {
let a = inputs.require_as(0)?;
let b = inputs.require_as(1)?;
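// Zero points are optional inputs (positions 2 and 3 in the ONNX
// MatMulInteger signature) and default to zero when absent.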
let a_zero_point = inputs.get_as(2)?;
let b_zero_point = inputs.get_as(3)?;
matmul_integer(pool, a, b, a_zero_point, b_zero_point).into_op_result()
}
}

#[cfg(test)]
mod tests {
use std::error::Error;
@@ -277,10 +398,11 @@ mod tests {
use rten_tensor::{Tensor, TensorView, TensorViewMut};

use crate::gemm::gemm;
use crate::ops::binary_elementwise::broadcast_shapes;
use crate::ops::tests::new_pool;
use crate::tensor_pool::AutoReturn;

use super::{gemm_op, matmul, matmul_impl, MatmulStrategy, OpError};
use super::{gemm_op, matmul, matmul_impl, matmul_integer, MatmulStrategy, OpError};

fn gemm_tensors(c: &mut Tensor, a: &Tensor, b: &Tensor, alpha: f32, beta: f32) {
c.make_contiguous();
@@ -323,6 +445,53 @@ mod tests {
});
}

fn reference_matmul_integer(
a: TensorView<u8>,
b: TensorView<i8>,
a_zero_point: Option<TensorView<u8>>,
b_zero_point: Option<TensorView<i8>>,
) -> Tensor<i32> {
let a_batch_dims = a.ndim() - 2;
let b_batch_dims = b.ndim() - 2;

let a_prefix = &a.shape()[..a.ndim() - 2];
let b_prefix = &b.shape()[..b.ndim() - 2];
let out_prefix = broadcast_shapes(a_prefix, b_prefix).unwrap();
let mut out_shape = out_prefix.to_vec();
out_shape.push(a.size(a.ndim() - 2));
out_shape.push(b.size(b.ndim() - 1));
let mut out = Tensor::<i32>::zeros(&out_shape);

let a_bcast = [out_prefix.as_slice(), &a.shape()[a_batch_dims..]].concat();
let b_bcast = [out_prefix.as_slice(), &b.shape()[b_batch_dims..]].concat();

let a_zero_point = a_zero_point.and_then(|zp| zp.item()).copied().unwrap_or(0) as i32;
let b_zero_point = b_zero_point.and_then(|zp| zp.item()).copied().unwrap_or(0) as i32;

a.broadcast(a_bcast.as_slice())
.inner_iter::<2>()
.zip(b.broadcast(b_bcast.as_slice()).inner_iter::<2>())
.zip(out.inner_iter_mut::<2>())
.for_each(|((a, b), mut c)| {
let [n_rows, n_cols] = c.shape();
let depth = a.size(1);

for i in 0..n_rows {
for j in 0..n_cols {
let mut y = 0;
for k in 0..depth {
let a_el = (a[[i, k]] as i32) - a_zero_point;
let b_el = (b[[k, j]] as i32) - b_zero_point;
y += a_el * b_el;
}
c[[i, j]] = y;
}
}
});

out
}

#[test]
fn test_gemm_op() -> Result<(), Box<dyn Error>> {
let pool = new_pool();
@@ -573,6 +742,130 @@ mod tests {
}
}

#[test]
fn test_matmul_integer() -> Result<(), Box<dyn Error>> {
struct Case {
a: Tensor<u8>,
b: Tensor<i8>,
a_zero_point: Option<Tensor<u8>>,
b_zero_point: Option<Tensor<i8>>,
expected_err: Option<OpError>,
}

let cases = [
// No zero point
Case {
a: Tensor::from([[1, 2], [3, 4]]),
b: Tensor::from([[5, 6], [7, 8]]),
a_zero_point: None,
b_zero_point: None,
expected_err: None,
},
// Scalar zero points
Case {
a: Tensor::from([[1, 2], [3, 4]]),
b: Tensor::from([[5, 6], [7, 8]]),
a_zero_point: Some(Tensor::from(127)),
b_zero_point: Some(Tensor::from(-50)),
expected_err: None,
},
// Non-scalar zero points
Case {
a: Tensor::from([[2, 2], [2, 2]]),
b: Tensor::from([[2, 2], [2, 2]]),
a_zero_point: Some(Tensor::from([[2, 2], [2, 2]])),
b_zero_point: None,
expected_err: Some(OpError::UnsupportedValue(
"Only scalar zero points are supported",
)),
},
Case {
a: Tensor::from([[2, 2], [2, 2]]),
b: Tensor::from([[2, 2], [2, 2]]),
a_zero_point: None,
b_zero_point: Some(Tensor::from([[2, 2], [2, 2]])),
expected_err: Some(OpError::UnsupportedValue(
"Only scalar zero points are supported",
)),
},
// Empty output
Case {
a: Tensor::zeros(&[0, 2]),
b: Tensor::zeros(&[2, 3]),
a_zero_point: None,
b_zero_point: None,
expected_err: None,
},
// Mismatched shapes
Case {
a: Tensor::zeros(&[1, 2]),
b: Tensor::zeros(&[3, 1]),
a_zero_point: None,
b_zero_point: None,
expected_err: Some(OpError::IncompatibleInputShapes(
"Columns of first matrix does not match rows of second matrix",
)),
},
Case {
a: Tensor::zeros(&[1]),
b: Tensor::zeros(&[3, 1]),
a_zero_point: None,
b_zero_point: None,
expected_err: Some(OpError::InvalidValue("Inputs must have >= 2 dimensions")),
},
Case {
a: Tensor::zeros(&[1, 2]),
b: Tensor::zeros(&[1]),
a_zero_point: None,
b_zero_point: None,
expected_err: Some(OpError::InvalidValue("Inputs must have >= 2 dimensions")),
},
Case {
a: Tensor::zeros(&[2, 2, 2]),
b: Tensor::zeros(&[3, 2, 2]),
a_zero_point: None,
b_zero_point: None,
expected_err: Some(OpError::IncompatibleInputShapes("Cannot broadcast shapes")),
},
];

let pool = new_pool();

for Case {
a,
b,
a_zero_point,
b_zero_point,
expected_err,
} in cases
{
let result = matmul_integer(
&pool,
a.view(),
b.view(),
a_zero_point.as_ref().map(|zp| zp.view()),
b_zero_point.as_ref().map(|zp| zp.view()),
);

match (result, expected_err) {
(Ok(result), None) => {
let expected = reference_matmul_integer(
a.view(),
b.view(),
a_zero_point.as_ref().map(|zp| zp.view()),
b_zero_point.as_ref().map(|zp| zp.view()),
);
assert_eq!(result, expected);
}
(result, expected_err) => {
assert_eq!(result.err(), expected_err);
}
}
}

Ok(())
}

#[test]
#[ignore]
fn bench_matmul() {
2 changes: 1 addition & 1 deletion src/ops/mod.rs
@@ -82,7 +82,7 @@ pub use layout::{
expand, flatten, reshape, squeeze, squeeze_in_place, Expand, Flatten, Reshape, Shape, Size,
Squeeze, Transpose, Unsqueeze,
};
pub use matmul::{gemm_op, matmul, Gemm, MatMul};
pub use matmul::{gemm_op, matmul, Gemm, MatMul, MatMulInteger};
pub use non_max_suppression::{non_max_suppression, BoxOrder, NonMaxSuppression};
pub use norm::{
batch_norm, batch_norm_in_place, instance_normalization, layer_normalization, log_softmax,
1 change: 1 addition & 0 deletions src/schema.fbs
@@ -121,6 +121,7 @@ enum OperatorType: ubyte {
DequantizeLinear,
QuantizeLinear,
DynamicQuantizeLinear,
MatMulInteger,
}

enum RNNDirection: ubyte {