diff --git a/src/gemm.rs b/src/gemm.rs
index 7e5c5b93..c19ace31 100644
--- a/src/gemm.rs
+++ b/src/gemm.rs
@@ -407,46 +407,36 @@ impl GemmExecutor {
     /// Prepack a matrix for use as the left-hand or "A" input.
     pub fn prepack_a(&self, a: Matrix) -> PackedAMatrix<'static> {
-        let mut data = Vec::new();
-        let packed = self.prepack_a_into(a, &mut data);
-        let rows = packed.rows;
-        let cols = packed.cols;
-        let panel_len = packed.panel_len;
-        let row_blocks = packed.row_blocks;
-
-        PackedAMatrix {
-            data: Cow::Owned(data),
-            rows,
-            cols,
-            panel_len,
-            row_blocks,
-        }
-    }
-
-    /// Prepack a matrix for use as the left-hand or "A" input, re-using an
-    /// existing buffer which will be resized as needed.
-    fn prepack_a_into<'a>(&self, a: Matrix, out: &'a mut Vec<f32>) -> PackedAMatrix<'a> {
         let kc = depth_block_size(a.cols());
         let mc = row_block_size(a.rows(), self.mr);
         let panel_len = kc * mc;
         let row_blocks = div_ceil(a.rows(), mc);
         let depth_blocks = div_ceil(a.cols(), kc);
 
-        out.resize(depth_blocks * row_blocks * panel_len, 0.);
+        let packed_len = depth_blocks * row_blocks * panel_len;
+        let mut data = Vec::with_capacity(packed_len);
 
         // Pack blocks in the order they will be accessed by the GEMM
         // implementation.
-        let mut out_panels = out.chunks_exact_mut(panel_len);
+        let mut out_panels = data.spare_capacity_mut()[..packed_len].chunks_exact_mut(panel_len);
+        let mut n_init = 0;
         for depth_range in range_chunks(0..a.cols(), kc) {
             for row_range in range_chunks(0..a.rows(), mc) {
                 let out_panel = out_panels.next().unwrap();
                 self.kernel
                     .pack_a_block(out_panel, a, row_range, depth_range.clone());
+                n_init += out_panel.len();
             }
         }
 
+        // Safety: We used `pack_a_block` to initialize `packed_len` elements.
+        assert!(n_init == packed_len);
+        unsafe {
+            data.set_len(packed_len);
+        }
+
         PackedAMatrix {
-            data: Cow::Borrowed(out),
+            data: Cow::Owned(data),
             rows: a.rows(),
             cols: a.cols(),
             panel_len,
@@ -462,19 +452,28 @@ impl GemmExecutor {
         let depth_blocks = div_ceil(a_cols, kc);
         let col_blocks = div_ceil(b.cols(), nc);
 
-        let mut out = vec![0.; col_blocks * depth_blocks * panel_len];
+        let packed_len = col_blocks * depth_blocks * panel_len;
+        let mut out = Vec::with_capacity(packed_len);
 
         // Pack blocks in the order they will be accessed by the GEMM
         // implementation.
-        let mut out_panels = out.chunks_exact_mut(panel_len);
+        let mut out_panels = out.spare_capacity_mut()[..packed_len].chunks_exact_mut(panel_len);
+        let mut n_init = 0;
         for col_range in range_chunks(0..b.cols(), nc) {
             for depth_range in range_chunks(0..a_cols, kc) {
                 let out_panel = out_panels.next().unwrap();
                 self.kernel
                     .pack_b_block(out_panel, b, depth_range, col_range.clone());
+                n_init += out_panel.len();
            }
         }
 
+        // Safety: We used `pack_b_block` to initialize `packed_len` elements.
+        assert!(n_init == packed_len);
+        unsafe {
+            out.set_len(packed_len);
+        }
+
         PackedBMatrix {
             data: out,
             rows: b.rows(),
@@ -791,13 +790,19 @@ fn gemm_impl(
             let packed_b = match b {
                 GemmInputB::Unpacked(b) => PACKED_B.with(|cell| {
                     let mut packed_b = cell.take();
-                    packed_b.resize(packed_b_size, 0.);
+                    packed_b.clear();
+                    packed_b.reserve(packed_b_size);
                     pack_b_block::<K>(
-                        &mut packed_b,
+                        &mut packed_b.spare_capacity_mut()[..packed_b_size],
                         b,
                         depth_range.clone(),
                         col_start..col_end,
                     );
+                    // Safety: pack_b_block initialized `packed_b_size`
+                    // elements.
+                    unsafe {
+                        packed_b.set_len(packed_b_size);
+                    }
                     thread_local_packed_b = Some(packed_b);
                     thread_local_packed_b.as_deref().unwrap()
                 }),
@@ -834,13 +839,19 @@ fn gemm_impl(
                 let packed_a = match a {
                     GemmInputA::Unpacked(a) => PACKED_A.with(|cell| {
                         let mut packed_a = cell.take();
-                        packed_a.resize(packed_a_size, 0.);
+                        packed_a.clear();
+                        packed_a.reserve(packed_a_size);
                         pack_a_block::<K>(
-                            &mut packed_a,
+                            &mut packed_a.spare_capacity_mut()[..packed_a_size],
                             a,
                             row_start..row_end,
                             depth_range.clone(),
                         );
+                        // Safety: `pack_a_block` will have initialized
+                        // `packed_a_size` elements.
+                        unsafe {
+                            packed_a.set_len(packed_a_size);
+                        }
                         thread_local_packed_a = Some(packed_a);
                         thread_local_packed_a.as_deref().unwrap()
                     }),
@@ -1706,29 +1717,6 @@ mod tests {
         }
     }
 
-    #[test]
-    #[ignore]
-    fn bench_pack_a() {
-        let gemm = GemmExecutor::new();
-        let mut rng = XorShiftRng::new(1234);
-        let m = 1024;
-        let n = 1024;
-        let iters = 1000;
-        let a = NdTensor::rand([m, n], &mut rng);
-
-        // Re-use a buffer across each call, so we measure the packing cost and
-        // not the allocation cost. This is appropriate for measuring packing
-        // cost for GEMM ops with unpacked inputs, as a thread-local packing
-        // buffer is allocated once and then re-used.
-        let mut packed_data = Vec::new();
-
-        run_bench(10, &format!("m {} n {} iters {}", m, n, iters), || {
-            for _i in 0..iters {
-                gemm.prepack_a_into(a.view(), &mut packed_data);
-            }
-        });
-    }
-
     // Like `bench_pack_a`, but this does include allocation costs, so is
     // relevant for ops which prepack inputs (eg. batched matmul).
     #[test]
diff --git a/src/gemm/kernels.rs b/src/gemm/kernels.rs
index 96392867..1de5568f 100644
--- a/src/gemm/kernels.rs
+++ b/src/gemm/kernels.rs
@@ -1,3 +1,4 @@
+use std::mem::MaybeUninit;
 use std::ops::Range;
 
 use rten_tensor::{Matrix, MatrixLayout};
@@ -257,10 +258,38 @@ pub trait Kernel {
 /// Object-safe trait for performing matrix multiplications and packing inputs
 /// with a specific kernel.
-pub trait GemmOps: Sync {
+///
+/// # Safety
+///
+/// The packing functions must initialize all elements of the output buffers
+/// passed to them.
+pub unsafe trait GemmOps: Sync {
     fn name(&self) -> &str;
-    fn pack_a_block(&self, out: &mut [f32], a: Matrix, rows: Range<usize>, cols: Range<usize>);
-    fn pack_b_block(&self, out: &mut [f32], a: Matrix, rows: Range<usize>, cols: Range<usize>);
+
+    /// Pack elements of `a` into a packing buffer for use by the matrix
+    /// multiplication kernel.
+    ///
+    /// Implementations must initialize all elements of `out`.
+    fn pack_a_block(
+        &self,
+        out: &mut [MaybeUninit<f32>],
+        a: Matrix,
+        rows: Range<usize>,
+        cols: Range<usize>,
+    );
+
+    /// Pack elements of `b` into a packing buffer for use by the matrix
+    /// multiplication kernel.
+    ///
+    /// Implementations must initialize all elements of `out`.
+    fn pack_b_block(
+        &self,
+        out: &mut [MaybeUninit<f32>],
+        a: Matrix,
+        rows: Range<usize>,
+        cols: Range<usize>,
+    );
+
     fn gemm(
         &self,
         out_data: &mut [f32],
@@ -278,14 +307,15 @@
 /// stable Rust.
 macro_rules! impl_gemmops {
     ($kernel:ident) => {
-        impl crate::gemm::kernels::GemmOps for $kernel {
+        // Safety - The packing functions initialize all elements of their output.
+        unsafe impl crate::gemm::kernels::GemmOps for $kernel {
            fn name(&self) -> &str {
                <$kernel as crate::gemm::kernels::Kernel>::name()
            }
 
            fn pack_a_block(
                &self,
-               out: &mut [f32],
+               out: &mut [std::mem::MaybeUninit<f32>],
                a: rten_tensor::Matrix,
                rows: std::ops::Range<usize>,
                cols: std::ops::Range<usize>,
@@ -295,7 +325,7 @@ macro_rules! impl_gemmops {
 
            fn pack_b_block(
                &self,
-               out: &mut [f32],
+               out: &mut [std::mem::MaybeUninit<f32>],
                a: rten_tensor::Matrix,
                rows: std::ops::Range<usize>,
                cols: std::ops::Range<usize>,
diff --git a/src/gemm/packing.rs b/src/gemm/packing.rs
index 8b32448c..d0476a0d 100644
--- a/src/gemm/packing.rs
+++ b/src/gemm/packing.rs
@@ -1,3 +1,4 @@
+use std::mem::MaybeUninit;
 use std::ops::Range;
 
 use rten_tensor::{Matrix, MatrixLayout};
@@ -11,7 +12,17 @@ use super::Kernel;
 /// row panels. Each row panel has size `K::MR * cols.len()` and uses
 /// column-major order. If `rows.len()` is not a multiple of `K::MR`, the
 /// final panel is zero-padded.
-pub fn pack_a_block<K: Kernel>(out: &mut [f32], a: Matrix, rows: Range<usize>, cols: Range<usize>) {
+///
+/// # Safety
+///
+/// When this function returns, all elements of `out` will have been initialized
+/// either to a value from `a`, or zero.
+pub fn pack_a_block<K: Kernel>(
+    out: &mut [MaybeUninit<f32>],
+    a: Matrix,
+    rows: Range<usize>,
+    cols: Range<usize>,
+) {
     let a_rows = rows.len();
     let a_cols = cols.len();
 
@@ -40,8 +51,8 @@ pub fn pack_a_block<K: Kernel>(out: &mut [f32], a: Matrix, rows: Range<usize>, c
                 for row in 0..K::MR {
                     // Safety: Indexes are less than lengths asserted above.
                     unsafe {
-                        *out.get_unchecked_mut(panel_offset + col * K::MR + row) =
-                            *a_data.get_unchecked(a_offset + row * row_stride + col);
+                        out.get_unchecked_mut(panel_offset + col * K::MR + row)
+                            .write(*a_data.get_unchecked(a_offset + row * row_stride + col));
                     }
                 }
             }
@@ -50,8 +61,12 @@ pub fn pack_a_block<K: Kernel>(out: &mut [f32], a: Matrix, rows: Range<usize>, c
                 for row in 0..K::MR {
                     // Safety: Indexes are less than lengths asserted above.
                     unsafe {
-                        *out.get_unchecked_mut(panel_offset + col * K::MR + row) = *a_data
-                            .get_unchecked(a_offset + row * row_stride + col * col_stride);
+                        out.get_unchecked_mut(panel_offset + col * K::MR + row)
+                            .write(
+                                *a_data.get_unchecked(
+                                    a_offset + row * row_stride + col * col_stride,
+                                ),
+                            );
                     }
                 }
             }
@@ -62,15 +77,21 @@ pub fn pack_a_block<K: Kernel>(out: &mut [f32], a: Matrix, rows: Range<usize>, c
                 let out_col_offset = panel_offset + col * K::MR;
                 for row in 0..K::MR {
                     let a_row = rows.start + panel_start_row + row;
-                    out[out_col_offset + row] = if a_row < rows.end {
+                    out[out_col_offset + row].write(if a_row < rows.end {
                         a_data[a_row * row_stride + (cols.start + col) * col_stride]
                     } else {
                         0.0
-                    };
+                    });
                 }
             }
         }
     }
+
+    // Initialize any spare capacity in the buffer.
+    let n_init = n_panels * a_cols * K::MR;
+    for x in &mut out[n_init..] {
+        x.write(0.);
+    }
 }
 
 /// Pack a block of the "B" matrix for use by kernel K.
 ///
@@ -79,7 +100,17 @@ pub fn pack_a_block<K: Kernel>(out: &mut [f32], a: Matrix, rows: Range<usize>, c
 /// K::NR)` column panels. Each column panel has size `rows.len() *
 /// K::NR` and uses row-major order. If `cols.len()` is not a multiple of
 /// `K::NR`, the final panel is zero-padded.
-pub fn pack_b_block<K: Kernel>(out: &mut [f32], b: Matrix, rows: Range<usize>, cols: Range<usize>) {
+///
+/// # Safety
+///
+/// When this function returns, all elements of `out` will have been initialized
+/// either to a value from `b`, or zero.
+pub fn pack_b_block<K: Kernel>(
+    out: &mut [MaybeUninit<f32>],
+    b: Matrix,
+    rows: Range<usize>,
+    cols: Range<usize>,
+) {
     let b_cols = cols.len();
     let b_rows = rows.len();
     let b_row_stride = b.row_stride();
@@ -113,8 +144,8 @@ pub fn pack_b_block<K: Kernel>(out: &mut [f32], b: Matrix, rows: Range<usize>, c
             for col in 0..K::NR {
                 // Safety: Indexes are less than lengths asserted above.
                 unsafe {
-                    *out.get_unchecked_mut(out_offset + col) =
-                        *b_data.get_unchecked(in_offset + col);
+                    out.get_unchecked_mut(out_offset + col)
+                        .write(*b_data.get_unchecked(in_offset + col));
                 }
             }
         }
@@ -125,8 +156,8 @@ pub fn pack_b_block<K: Kernel>(out: &mut [f32], b: Matrix, rows: Range<usize>, c
             for col in 0..K::NR {
                 // Safety: Indexes are less than lengths asserted above.
                 unsafe {
-                    *out.get_unchecked_mut(out_offset + col) =
-                        *b_data.get_unchecked(in_offset + col * b_col_stride);
+                    out.get_unchecked_mut(out_offset + col)
+                        .write(*b_data.get_unchecked(in_offset + col * b_col_stride));
                 }
             }
         }
@@ -142,13 +173,19 @@ pub fn pack_b_block<K: Kernel>(out: &mut [f32], b: Matrix, rows: Range<usize>, c
                     let b_offset =
                         b_row_offset + (cols.start + panel_start_col + col) * b_col_stride;
-                    out[out_row_offset + col] = if out_col < b_cols {
+                    out[out_row_offset + col].write(if out_col < b_cols {
                         b_data[b_offset]
                     } else {
                         0.0
-                    };
+                    });
                 }
             }
         }
     }
+
+    // Initialize any spare capacity in the buffer.
+    let n_init = n_panels * b_rows * K::NR;
+    for x in &mut out[n_init..] {
+        x.write(0.);
+    }
 }
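
Note (reviewer sketch, not part of the patch): the pattern this diff applies everywhere is to allocate a packing buffer with `Vec::with_capacity`, have the packing routine write every element of the uninitialized spare capacity exactly once through `MaybeUninit::write`, and only then call the unsafe `Vec::set_len`. A minimal, self-contained illustration of that pattern, with a hypothetical `fill` helper standing in for `pack_a_block` / `pack_b_block`:

// Reviewer sketch: write into a Vec's spare capacity, then set its length.
use std::mem::MaybeUninit;

// Stand-in for `pack_a_block` / `pack_b_block`: writes every element of
// `out` exactly once and reports how many elements it initialized.
fn fill(out: &mut [MaybeUninit<f32>]) -> usize {
    for (i, slot) in out.iter_mut().enumerate() {
        slot.write(i as f32);
    }
    out.len()
}

fn main() {
    let packed_len = 8;

    // Allocate without zero-initializing the contents.
    let mut buf: Vec<f32> = Vec::with_capacity(packed_len);

    // Write into the uninitialized spare capacity.
    let n_init = fill(&mut buf.spare_capacity_mut()[..packed_len]);

    // Safety: `fill` initialized the first `packed_len` elements of the
    // spare capacity, as checked by the assertion.
    assert!(n_init == packed_len);
    unsafe {
        buf.set_len(packed_len);
    }

    assert_eq!(buf, vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);
}

As in the patch, the `assert!(n_init == packed_len)` check is what justifies the subsequent `set_len` call: if a packing routine ever skipped elements, the assertion would fail before any uninitialized memory could be observed as `f32` values.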