Skip to content

Commit

Permalink
blas: Simplify control flow in matrix multiply
Browse files (browse the repository at this point in the history)
  • Loading branch information
bluss committed Aug 9, 2024
1 parent 453eae3 commit 0153a37
Showing 1 changed file with 62 additions and 86 deletions.
148 changes: 62 additions & 86 deletions src/linalg/impl_linalg.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -371,32 +371,15 @@ where
#[cfg(not(feature = "blas"))]
use self::mat_mul_general as mat_mul_impl;

#[rustfmt::skip]
#[cfg(feature = "blas")]
fn mat_mul_impl<A>(
alpha: A,
a: &ArrayView2<'_, A>,
b: &ArrayView2<'_, A>,
beta: A,
c: &mut ArrayViewMut2<'_, A>,
) where
A: LinalgScalar,
fn mat_mul_impl<A>(alpha: A, a: &ArrayView2<'_, A>, b: &ArrayView2<'_, A>, beta: A, c: &mut ArrayViewMut2<'_, A>)
where A: LinalgScalar
{
// size cutoff for using BLAS
let cut = GEMM_BLAS_CUTOFF;
let ((m, k), (k2, n)) = (a.dim(), b.dim());
debug_assert_eq!(k, k2);
if !(m > cut || n > cut || k > cut)
|| !(same_type::<A, f32>()
|| same_type::<A, f64>()
|| same_type::<A, c32>()
|| same_type::<A, c64>())
if (m > GEMM_BLAS_CUTOFF || n > GEMM_BLAS_CUTOFF || k > GEMM_BLAS_CUTOFF)
&& (same_type::<A, f32>() || same_type::<A, f64>() || same_type::<A, c32>() || same_type::<A, c64>())
{
return mat_mul_general(alpha, a, b, beta, c);
}

#[allow(clippy::never_loop)] // MSRV Rust 1.64 does not have break from block
'blas_block: loop {
// Compute A B -> C
// We require for BLAS compatibility that:
// A, B, C are contiguous (stride=1) in their fastest dimension,
Expand All @@ -408,75 +391,68 @@ fn mat_mul_impl<A>(
// Apply transpose to A, B as needed if they differ from the row major case.
// If C is CblasColMajor then transpose both A, B (again!)

let (a_layout, b_layout, c_layout) =
if let (Some(a_layout), Some(b_layout), Some(c_layout)) =
(get_blas_compatible_layout(a),
get_blas_compatible_layout(b),
get_blas_compatible_layout(c))
{
(a_layout, b_layout, c_layout)
} else {
break 'blas_block;
};

let cblas_layout = c_layout.to_cblas_layout();
let a_trans = a_layout.to_cblas_transpose_for(cblas_layout);
let lda = blas_stride(&a, a_layout);

let b_trans = b_layout.to_cblas_transpose_for(cblas_layout);
let ldb = blas_stride(&b, b_layout);

let ldc = blas_stride(&c, c_layout);

macro_rules! gemm_scalar_cast {
(f32, $var:ident) => {
cast_as(&$var)
};
(f64, $var:ident) => {
cast_as(&$var)
};
(c32, $var:ident) => {
&$var as *const A as *const _
};
(c64, $var:ident) => {
&$var as *const A as *const _
};
}
if let (Some(a_layout), Some(b_layout), Some(c_layout)) =
(get_blas_compatible_layout(a), get_blas_compatible_layout(b), get_blas_compatible_layout(c))
{
let cblas_layout = c_layout.to_cblas_layout();
let a_trans = a_layout.to_cblas_transpose_for(cblas_layout);
let lda = blas_stride(&a, a_layout);

let b_trans = b_layout.to_cblas_transpose_for(cblas_layout);
let ldb = blas_stride(&b, b_layout);

let ldc = blas_stride(&c, c_layout);

macro_rules! gemm_scalar_cast {
(f32, $var:ident) => {
cast_as(&$var)
};
(f64, $var:ident) => {
cast_as(&$var)
};
(c32, $var:ident) => {
&$var as *const A as *const _
};
(c64, $var:ident) => {
&$var as *const A as *const _
};
}

macro_rules! gemm {
($ty:tt, $gemm:ident) => {
if same_type::<A, $ty>() {
// gemm is C ← αA^Op B^Op + βC
// Where Op is notrans/trans/conjtrans
unsafe {
blas_sys::$gemm(
cblas_layout,
a_trans,
b_trans,
m as blas_index, // m, rows of Op(a)
n as blas_index, // n, cols of Op(b)
k as blas_index, // k, cols of Op(a)
gemm_scalar_cast!($ty, alpha), // alpha
a.ptr.as_ptr() as *const _, // a
lda, // lda
b.ptr.as_ptr() as *const _, // b
ldb, // ldb
gemm_scalar_cast!($ty, beta), // beta
c.ptr.as_ptr() as *mut _, // c
ldc, // ldc
);
macro_rules! gemm {
($ty:tt, $gemm:ident) => {
if same_type::<A, $ty>() {
// gemm is C ← αA^Op B^Op + βC
// Where Op is notrans/trans/conjtrans
unsafe {
blas_sys::$gemm(
cblas_layout,
a_trans,
b_trans,
m as blas_index, // m, rows of Op(a)
n as blas_index, // n, cols of Op(b)
k as blas_index, // k, cols of Op(a)
gemm_scalar_cast!($ty, alpha), // alpha
a.ptr.as_ptr() as *const _, // a
lda, // lda
b.ptr.as_ptr() as *const _, // b
ldb, // ldb
gemm_scalar_cast!($ty, beta), // beta
c.ptr.as_ptr() as *mut _, // c
ldc, // ldc
);
}
return;
}
return;
}
};
}
};
}

gemm!(f32, cblas_sgemm);
gemm!(f64, cblas_dgemm);
gemm!(c32, cblas_cgemm);
gemm!(c64, cblas_zgemm);
gemm!(f32, cblas_sgemm);
gemm!(f64, cblas_dgemm);
gemm!(c32, cblas_cgemm);
gemm!(c64, cblas_zgemm);

break 'blas_block;
unreachable!() // we checked above that A is one of f32, f64, c32, c64
}
}
mat_mul_general(alpha, a, b, beta, c)
}
Expand Down

0 comments on commit 0153a37

Please sign in to comment.