Merge pull request #70 from robertknight/gemm-pack-uninit-2
Avoid zero-initialization of packing buffers
robertknight authored Mar 30, 2024
2 parents 864b78a + 87c0171 commit ae2b8da
Showing 3 changed files with 126 additions and 71 deletions.
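
The change running through all three files is the same idiom: instead of allocating packing buffers with `vec![0.; n]` and then overwriting every element, the buffers are created with `Vec::with_capacity`, filled through `MaybeUninit` writes into `spare_capacity_mut()`, and only then marked initialized with `set_len`. A minimal, self-contained sketch of that idiom (the `fill_uninit` helper and its closure are illustrative, not code from this repository):

```rust
use std::mem::MaybeUninit;

/// Build a `Vec<f32>` of `len` elements without zero-initializing it first.
/// `fill` must write every element of the slice it is given.
fn fill_uninit(len: usize, fill: impl FnOnce(&mut [MaybeUninit<f32>])) -> Vec<f32> {
    let mut data = Vec::with_capacity(len);
    fill(&mut data.spare_capacity_mut()[..len]);
    // Safety: `fill` initialized the first `len` elements.
    unsafe { data.set_len(len) };
    data
}

fn main() {
    // Equivalent to `vec![0.; 4]` followed by overwriting each element,
    // but without the redundant zeroing pass.
    let v = fill_uninit(4, |out| {
        for (i, x) in out.iter_mut().enumerate() {
            x.write(i as f32);
        }
    });
    assert_eq!(v, [0., 1., 2., 3.]);
}
```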
90 changes: 39 additions & 51 deletions src/gemm.rs
@@ -407,46 +407,36 @@ impl GemmExecutor {

/// Prepack a matrix for use as the left-hand or "A" input.
pub fn prepack_a(&self, a: Matrix) -> PackedAMatrix<'static> {
let mut data = Vec::new();
let packed = self.prepack_a_into(a, &mut data);
let rows = packed.rows;
let cols = packed.cols;
let panel_len = packed.panel_len;
let row_blocks = packed.row_blocks;

PackedAMatrix {
data: Cow::Owned(data),
rows,
cols,
panel_len,
row_blocks,
}
}

/// Prepack a matrix for use as the left-hand or "A" input, re-using an
/// existing buffer which will be resized as needed.
fn prepack_a_into<'a>(&self, a: Matrix, out: &'a mut Vec<f32>) -> PackedAMatrix<'a> {
let kc = depth_block_size(a.cols());
let mc = row_block_size(a.rows(), self.mr);
let panel_len = kc * mc;
let row_blocks = div_ceil(a.rows(), mc);
let depth_blocks = div_ceil(a.cols(), kc);

out.resize(depth_blocks * row_blocks * panel_len, 0.);
let packed_len = depth_blocks * row_blocks * panel_len;
let mut data = Vec::with_capacity(packed_len);

// Pack blocks in the order they will be accessed by the GEMM
// implementation.
let mut out_panels = out.chunks_exact_mut(panel_len);
let mut out_panels = data.spare_capacity_mut()[..packed_len].chunks_exact_mut(panel_len);
let mut n_init = 0;
for depth_range in range_chunks(0..a.cols(), kc) {
for row_range in range_chunks(0..a.rows(), mc) {
let out_panel = out_panels.next().unwrap();
self.kernel
.pack_a_block(out_panel, a, row_range, depth_range.clone());
n_init += out_panel.len();
}
}

// Safety: We used `pack_a_block` to initialize `packed_len` elements.
assert!(n_init == packed_len);
unsafe {
data.set_len(packed_len);
}

PackedAMatrix {
data: Cow::Borrowed(out),
data: Cow::Owned(data),
rows: a.rows(),
cols: a.cols(),
panel_len,
@@ -462,19 +452,28 @@ impl GemmExecutor {
let depth_blocks = div_ceil(a_cols, kc);
let col_blocks = div_ceil(b.cols(), nc);

let mut out = vec![0.; col_blocks * depth_blocks * panel_len];
let packed_len = col_blocks * depth_blocks * panel_len;
let mut out = Vec::with_capacity(packed_len);

// Pack blocks in the order they will be accessed by the GEMM
// implementation.
let mut out_panels = out.chunks_exact_mut(panel_len);
let mut out_panels = out.spare_capacity_mut()[..packed_len].chunks_exact_mut(panel_len);
let mut n_init = 0;
for col_range in range_chunks(0..b.cols(), nc) {
for depth_range in range_chunks(0..a_cols, kc) {
let out_panel = out_panels.next().unwrap();
self.kernel
.pack_b_block(out_panel, b, depth_range, col_range.clone());
n_init += out_panel.len();
}
}

// Safety: We used `pack_b_block` to initialize `packed_len` elements.
assert!(n_init == packed_len);
unsafe {
out.set_len(packed_len);
}

PackedBMatrix {
data: out,
rows: b.rows(),
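
Both `prepack_a` and `prepack_b` now carve the spare capacity into panels with `chunks_exact_mut`, let the kernel's pack routine fill each panel, and count the written elements in `n_init`; the `assert!(n_init == packed_len)` before `set_len` turns the safety argument into a runtime check. A simplified sketch of that accounting, with a plain element-writing loop standing in for `pack_a_block`/`pack_b_block`:

```rust
use std::mem::MaybeUninit;

fn pack_panels(num_panels: usize, panel_len: usize) -> Vec<f32> {
    let packed_len = num_panels * panel_len;
    let mut data = Vec::with_capacity(packed_len);

    // Fill the uninitialized panels in the order the kernel will read them.
    let mut out_panels = data.spare_capacity_mut()[..packed_len].chunks_exact_mut(panel_len);
    let mut n_init = 0;
    for panel_idx in 0..num_panels {
        let out_panel = out_panels.next().unwrap();
        // Stand-in for `pack_a_block` / `pack_b_block`: writes every element.
        for (i, x) in out_panel.iter_mut().enumerate() {
            x.write((panel_idx * panel_len + i) as f32);
        }
        n_init += out_panel.len();
    }

    // Safety: the loop above initialized exactly `packed_len` elements.
    assert!(n_init == packed_len);
    unsafe { data.set_len(packed_len) };
    data
}

fn main() {
    let packed = pack_panels(3, 4);
    assert_eq!(packed.len(), 12);
    assert_eq!(packed[5], 5.0);
}
```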
@@ -791,13 +790,19 @@ fn gemm_impl<K: Kernel, const MR_NR: usize>(
let packed_b = match b {
GemmInputB::Unpacked(b) => PACKED_B.with(|cell| {
let mut packed_b = cell.take();
packed_b.resize(packed_b_size, 0.);
packed_b.clear();
packed_b.reserve(packed_b_size);
pack_b_block::<K>(
&mut packed_b,
&mut packed_b.spare_capacity_mut()[..packed_b_size],
b,
depth_range.clone(),
col_start..col_end,
);
// Safety: pack_b_block initialized `packed_b_size`
// elements.
unsafe {
packed_b.set_len(packed_b_size);
}
thread_local_packed_b = Some(packed_b);
thread_local_packed_b.as_deref().unwrap()
}),
@@ -834,13 +839,19 @@ fn gemm_impl<K: Kernel, const MR_NR: usize>(
let packed_a = match a {
GemmInputA::Unpacked(a) => PACKED_A.with(|cell| {
let mut packed_a = cell.take();
packed_a.resize(packed_a_size, 0.);
packed_a.clear();
packed_a.reserve(packed_a_size);
pack_a_block::<K>(
&mut packed_a,
&mut packed_a.spare_capacity_mut()[..packed_a_size],
a,
row_start..row_end,
depth_range.clone(),
);
// Safety: `pack_a_block` will have initialized
// `packed_a_size` elements.
unsafe {
packed_a.set_len(packed_a_size);
}
thread_local_packed_a = Some(packed_a);
thread_local_packed_a.as_deref().unwrap()
}),
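
The thread-local path applies the same idea to buffers that are reused across calls: `clear()` keeps the allocation while discarding old contents, `reserve()` grows it only when needed, and the pack step fills the spare capacity before `set_len`. A rough sketch of that reuse pattern (the `SCRATCH` thread-local and `fill` callback are illustrative stand-ins for `PACKED_A`/`PACKED_B` and the pack functions):

```rust
use std::cell::Cell;
use std::mem::MaybeUninit;

thread_local! {
    static SCRATCH: Cell<Vec<f32>> = Cell::new(Vec::new());
}

/// Fill `len` elements of the thread-local scratch buffer and run
/// `use_packed` on the result. The allocation is kept for the next call.
fn with_packed<R>(
    len: usize,
    fill: impl FnOnce(&mut [MaybeUninit<f32>]),
    use_packed: impl FnOnce(&[f32]) -> R,
) -> R {
    SCRATCH.with(|cell| {
        let mut buf = cell.take();
        buf.clear();
        buf.reserve(len);
        fill(&mut buf.spare_capacity_mut()[..len]);
        // Safety: `fill` initialized the first `len` elements.
        unsafe { buf.set_len(len) };
        let result = use_packed(&buf);
        cell.set(buf); // Hand the buffer back for reuse.
        result
    })
}

fn main() {
    let sum = with_packed(
        8,
        |out| {
            for (i, x) in out.iter_mut().enumerate() {
                x.write(i as f32);
            }
        },
        |packed| packed.iter().sum::<f32>(),
    );
    assert_eq!(sum, 28.0);
}
```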
@@ -1706,29 +1717,6 @@ mod tests {
}
}

#[test]
#[ignore]
fn bench_pack_a() {
let gemm = GemmExecutor::new();
let mut rng = XorShiftRng::new(1234);
let m = 1024;
let n = 1024;
let iters = 1000;
let a = NdTensor::rand([m, n], &mut rng);

// Re-use a buffer across each call, so we measure the packing cost and
// not the allocation cost. This is appropriate for measuring packing
// cost for GEMM ops with unpacked inputs, as a thread-local packing
// buffer is allocated once and then re-used.
let mut packed_data = Vec::new();

run_bench(10, &format!("m {} n {} iters {}", m, n, iters), || {
for _i in 0..iters {
gemm.prepack_a_into(a.view(), &mut packed_data);
}
});
}

// Like `bench_pack_a`, but this does include allocation costs, so is
// relevant for ops which prepack inputs (eg. batched matmul).
#[test]
42 changes: 36 additions & 6 deletions src/gemm/kernels.rs
@@ -1,3 +1,4 @@
use std::mem::MaybeUninit;
use std::ops::Range;

use rten_tensor::{Matrix, MatrixLayout};
@@ -257,10 +258,38 @@ pub trait Kernel {

/// Object-safe trait for performing matrix multiplications and packing inputs
/// with a specific kernel.
pub trait GemmOps: Sync {
///
/// # Safety
///
/// The packing functions must initialize all elements of the output buffers
/// passed to them.
pub unsafe trait GemmOps: Sync {
fn name(&self) -> &str;
fn pack_a_block(&self, out: &mut [f32], a: Matrix, rows: Range<usize>, cols: Range<usize>);
fn pack_b_block(&self, out: &mut [f32], a: Matrix, rows: Range<usize>, cols: Range<usize>);

/// Pack elements of `a` into a packing buffer for use by the matrix
/// multiplication kernel.
///
/// Implementations must initialize all elements of `out`.
fn pack_a_block(
&self,
out: &mut [MaybeUninit<f32>],
a: Matrix,
rows: Range<usize>,
cols: Range<usize>,
);

/// Pack elements of `b` into a packing buffer for use by the matrix
/// multiplication kernel.
///
/// Implementations must initialize all elements of `out`.
fn pack_b_block(
&self,
out: &mut [MaybeUninit<f32>],
a: Matrix,
rows: Range<usize>,
cols: Range<usize>,
);

fn gemm(
&self,
out_data: &mut [f32],
@@ -278,14 +307,15 @@ pub trait GemmOps: Sync {
/// stable Rust.
macro_rules! impl_gemmops {
($kernel:ident) => {
impl crate::gemm::kernels::GemmOps for $kernel {
// Safety - The packing functions initialize all elements of their output.
unsafe impl crate::gemm::kernels::GemmOps for $kernel {
fn name(&self) -> &str {
<$kernel as crate::gemm::kernels::Kernel>::name()
}

fn pack_a_block(
&self,
out: &mut [f32],
out: &mut [std::mem::MaybeUninit<f32>],
a: rten_tensor::Matrix,
rows: std::ops::Range<usize>,
cols: std::ops::Range<usize>,
@@ -295,7 +325,7 @@ macro_rules! impl_gemmops {

fn pack_b_block(
&self,
out: &mut [f32],
out: &mut [std::mem::MaybeUninit<f32>],
a: rten_tensor::Matrix,
rows: std::ops::Range<usize>,
cols: std::ops::Range<usize>,
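
`GemmOps` becoming an `unsafe trait` is what lets the callers above call `set_len` on the strength of the trait contract alone: every implementation promises to fully initialize `out`, and unrelated code is allowed to rely on that promise for memory safety. A toy illustration of this split of responsibility, using hypothetical names (`FillOps`, `Ones`, `filled_vec`) rather than the crate's API:

```rust
use std::mem::MaybeUninit;

/// # Safety
///
/// `fill` must initialize every element of `out`.
unsafe trait FillOps {
    fn fill(&self, out: &mut [MaybeUninit<f32>]);
}

struct Ones;

// Safety: `fill` writes every element of `out`.
unsafe impl FillOps for Ones {
    fn fill(&self, out: &mut [MaybeUninit<f32>]) {
        for x in out {
            x.write(1.0);
        }
    }
}

/// The caller relies on the trait's contract, not on the concrete impl.
fn filled_vec(ops: &dyn FillOps, len: usize) -> Vec<f32> {
    let mut data = Vec::with_capacity(len);
    ops.fill(&mut data.spare_capacity_mut()[..len]);
    // Safety: `FillOps::fill` is required to initialize all of `out`.
    unsafe { data.set_len(len) };
    data
}

fn main() {
    assert_eq!(filled_vec(&Ones, 3), [1.0, 1.0, 1.0]);
}
```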
65 changes: 51 additions & 14 deletions src/gemm/packing.rs
@@ -1,3 +1,4 @@
use std::mem::MaybeUninit;
use std::ops::Range;

use rten_tensor::{Matrix, MatrixLayout};
@@ -11,7 +12,17 @@ use super::Kernel;
/// row panels. Each row panel has size `K::MR * cols.len()` and uses
/// column-major order. If `rows.len()` is not a multiple of `K::MR`, the
/// final panel is zero-padded.
pub fn pack_a_block<K: Kernel>(out: &mut [f32], a: Matrix, rows: Range<usize>, cols: Range<usize>) {
///
/// # Safety
///
/// When this function returns, all elements of `out` will have been initialized
/// either to a value from `a`, or zero.
pub fn pack_a_block<K: Kernel>(
out: &mut [MaybeUninit<f32>],
a: Matrix,
rows: Range<usize>,
cols: Range<usize>,
) {
let a_rows = rows.len();
let a_cols = cols.len();

@@ -40,8 +51,8 @@ pub fn pack_a_block<K: Kernel>(out: &mut [f32], a: Matrix, rows: Range<usize>, c
for row in 0..K::MR {
// Safety: Indexes are less than lengths asserted above.
unsafe {
*out.get_unchecked_mut(panel_offset + col * K::MR + row) =
*a_data.get_unchecked(a_offset + row * row_stride + col);
out.get_unchecked_mut(panel_offset + col * K::MR + row)
.write(*a_data.get_unchecked(a_offset + row * row_stride + col));
}
}
}
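
The per-element stores change from `*out.get_unchecked_mut(..) = v` to `out.get_unchecked_mut(..).write(v)` because the destination slot is now a `MaybeUninit<f32>` rather than an `f32`; `MaybeUninit::write` stores the value without reading or dropping whatever was there before. A tiny standalone illustration of `write`:

```rust
use std::mem::MaybeUninit;

fn main() {
    let mut slot: MaybeUninit<f32> = MaybeUninit::uninit();

    // `write` initializes the slot and returns a mutable reference to the
    // value, without touching the previous (uninitialized) contents.
    let value: &mut f32 = slot.write(2.5);
    *value += 0.5;

    // Safety: `slot` was initialized by `write` above.
    assert_eq!(unsafe { slot.assume_init() }, 3.0);
}
```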
@@ -50,8 +61,12 @@ pub fn pack_a_block<K: Kernel>(out: &mut [f32], a: Matrix, rows: Range<usize>, c
for row in 0..K::MR {
// Safety: Indexes are less than lengths asserted above.
unsafe {
*out.get_unchecked_mut(panel_offset + col * K::MR + row) = *a_data
.get_unchecked(a_offset + row * row_stride + col * col_stride);
out.get_unchecked_mut(panel_offset + col * K::MR + row)
.write(
*a_data.get_unchecked(
a_offset + row * row_stride + col * col_stride,
),
);
}
}
}
@@ -62,15 +77,21 @@ pub fn pack_a_block<K: Kernel>(out: &mut [f32], a: Matrix, rows: Range<usize>, c
let out_col_offset = panel_offset + col * K::MR;
for row in 0..K::MR {
let a_row = rows.start + panel_start_row + row;
out[out_col_offset + row] = if a_row < rows.end {
out[out_col_offset + row].write(if a_row < rows.end {
a_data[a_row * row_stride + (cols.start + col) * col_stride]
} else {
0.0
};
});
}
}
}
}

// Initialize any spare capacity in the buffer.
let n_init = n_panels * a_cols * K::MR;
for x in &mut out[n_init..] {
x.write(0.);
}
}
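
The padding paths go through `MaybeUninit::write` as well: rows beyond the end of `a` are written as zeros, and the trailing loop initializes whatever remains of `out` past the packed panels, so every element the caller will later expose via `set_len` has been written. A small sketch of zero-padding a partial final panel in this style (the `PANEL` size and input data are made up for illustration; the real code packs panels in column-major order):

```rust
use std::mem::MaybeUninit;

const PANEL: usize = 4;

/// Pack `src` into panels of `PANEL` values, zero-padding the final panel
/// and any spare capacity beyond it.
fn pack_padded(src: &[f32], out: &mut [MaybeUninit<f32>]) {
    let n_panels = src.len().div_ceil(PANEL);
    assert!(out.len() >= n_panels * PANEL);

    for (i, x) in out.iter_mut().take(n_panels * PANEL).enumerate() {
        // Elements past the end of `src` become zero padding.
        x.write(src.get(i).copied().unwrap_or(0.0));
    }
    // Initialize any remaining spare capacity so the whole buffer is
    // initialized when this function returns.
    for x in &mut out[n_panels * PANEL..] {
        x.write(0.0);
    }
}

fn main() {
    let src = [1.0, 2.0, 3.0, 4.0, 5.0];
    let mut buf: Vec<f32> = Vec::with_capacity(12);
    pack_padded(&src, &mut buf.spare_capacity_mut()[..12]);
    // Safety: `pack_padded` initialized all 12 elements.
    unsafe { buf.set_len(12) };
    assert_eq!(buf, [1., 2., 3., 4., 5., 0., 0., 0., 0., 0., 0., 0.]);
}
```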

/// Pack a block of the "B" matrix for use by kernel K.
@@ -79,7 +100,17 @@ pub fn pack_a_block<K: Kernel>(out: &mut [f32], a: Matrix, rows: Range<usize>, c
/// K::NR)` column panels. Each column panel has size `rows.len() *
/// K::NR` and uses row-major order. If `cols.len()` is not a multiple of
/// `K::NR`, the final panel is zero-padded.
pub fn pack_b_block<K: Kernel>(out: &mut [f32], b: Matrix, rows: Range<usize>, cols: Range<usize>) {
///
/// # Safety
///
/// When this function returns, all elements of `out` will have been initialized
/// either to a value from `b`, or zero.
pub fn pack_b_block<K: Kernel>(
out: &mut [MaybeUninit<f32>],
b: Matrix,
rows: Range<usize>,
cols: Range<usize>,
) {
let b_cols = cols.len();
let b_rows = rows.len();
let b_row_stride = b.row_stride();
@@ -113,8 +144,8 @@ pub fn pack_b_block<K: Kernel>(out: &mut [f32], b: Matrix, rows: Range<usize>, c
for col in 0..K::NR {
// Safety: Indexes are less than lengths asserted above.
unsafe {
*out.get_unchecked_mut(out_offset + col) =
*b_data.get_unchecked(in_offset + col);
out.get_unchecked_mut(out_offset + col)
.write(*b_data.get_unchecked(in_offset + col));
}
}
}
@@ -125,8 +156,8 @@ pub fn pack_b_block<K: Kernel>(out: &mut [f32], b: Matrix, rows: Range<usize>, c
for col in 0..K::NR {
// Safety: Indexes are less than lengths asserted above.
unsafe {
*out.get_unchecked_mut(out_offset + col) =
*b_data.get_unchecked(in_offset + col * b_col_stride);
out.get_unchecked_mut(out_offset + col)
.write(*b_data.get_unchecked(in_offset + col * b_col_stride));
}
}
}
@@ -142,13 +173,19 @@ pub fn pack_b_block<K: Kernel>(out: &mut [f32], b: Matrix, rows: Range<usize>, c
let b_offset =
b_row_offset + (cols.start + panel_start_col + col) * b_col_stride;

out[out_row_offset + col] = if out_col < b_cols {
out[out_row_offset + col].write(if out_col < b_cols {
b_data[b_offset]
} else {
0.0
};
});
}
}
}
}

// Initialize any spare capacity in the buffer.
let n_init = n_panels * b_rows * K::NR;
for x in &mut out[n_init..] {
x.write(0.);
}
}
