Skip to content

Commit

Permalink
Merge pull request #598 from robertknight/assume-init-refactor
Browse files Browse the repository at this point in the history
Refactor and export `AssumeInit` trait from rten-tensor and use it downstream
  • Loading branch information
robertknight authored Feb 15, 2025
2 parents 3d0fb4e + a7d7b0b commit 024c1e9
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 48 deletions.
89 changes: 89 additions & 0 deletions rten-tensor/src/assume_init.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
use std::mem::MaybeUninit;

/// Trait for converting collections of uninitialized (`MaybeUninit<T>`) values
/// to collections of corresponding initializes values (`T`).
///
/// ## Example
///
/// ```
/// use std::mem::MaybeUninit;
/// use rten_tensor::AssumeInit;
///
/// fn scale_values<'a>(dst: &'a mut [MaybeUninit<f32>], src: &[f32], scale: f32) -> &'a mut [f32] {
/// for (y, x) in dst.into_iter().zip(src) {
/// y.write(x * scale);
/// }
/// // Safety: All elements have been initialized.
/// unsafe { dst.assume_init() }
/// }
///
/// let src = [1., 2., 3.];
/// let mut dst = [MaybeUninit::uninit(); 3];
/// let scaled = scale_values(&mut dst, &src, 2.);
/// assert_eq!(scaled, [2., 4., 6.]);
/// ```
pub trait AssumeInit {
/// The type of the initialized storage.
type Output;

/// Cast `self` to a collection of initialized values.
///
/// # Safety
///
/// The caller must guarantee that all elements have been initialized.
unsafe fn assume_init(self) -> Self::Output;
}

impl<T> AssumeInit for Vec<MaybeUninit<T>> {
type Output = Vec<T>;

unsafe fn assume_init(mut self) -> Self::Output {
let (ptr, len, capacity) = (self.as_mut_ptr(), self.len(), self.capacity());

// Don't drop self, as that would deallocate.
std::mem::forget(self);

// Safety: We're re-constructing a `Vec` with the same length and
// capacity and an element type that has the same size and alignment,
// just cast from uninitialized to initialized.
unsafe { Vec::from_raw_parts(ptr as *mut T, len, capacity) }
}
}

impl<'a, T> AssumeInit for &'a [MaybeUninit<T>] {
type Output = &'a [T];

unsafe fn assume_init(self) -> Self::Output {
std::mem::transmute(self)
}
}

impl<'a, T> AssumeInit for &'a mut [MaybeUninit<T>] {
type Output = &'a mut [T];

unsafe fn assume_init(self) -> Self::Output {
std::mem::transmute(self)
}
}

#[cfg(test)]
mod tests {
use std::mem::MaybeUninit;

use super::AssumeInit;

#[test]
fn test_assume_init_vec() {
let mut vec = vec![MaybeUninit::uninit(); 3];
vec.reserve(4);

for x in &mut vec {
x.write(2.);
}

let vec = unsafe { vec.assume_init() };
assert_eq!(vec.len(), 3);
assert_eq!(vec.capacity(), 7);
assert_eq!(vec, &[2., 2., 2.]);
}
}
2 changes: 2 additions & 0 deletions rten-tensor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
//! }
//! ```
mod assume_init;
mod copy;
pub mod errors;
mod index_iterator;
Expand Down Expand Up @@ -119,6 +120,7 @@ impl Alloc for GlobalAlloc {
}
}

pub use assume_init::AssumeInit;
pub use index_iterator::{DynIndices, Indices, NdIndices};
pub use iterators::{
AxisChunks, AxisChunksMut, AxisIter, AxisIterMut, InnerIter, InnerIterMut, Iter, IterMut,
Expand Down
19 changes: 19 additions & 0 deletions rten-tensor/src/storage.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::borrow::Cow;
use std::marker::PhantomData;
use std::mem::MaybeUninit;
use std::ops::Range;

use crate::assume_init::AssumeInit;

/// Trait for backing storage used by tensors and views.
///
/// Mutable tensors have storage which also implement [`StorageMut`].
Expand Down Expand Up @@ -349,6 +352,14 @@ unsafe impl<T> Storage for ViewData<'_, T> {
}
}

impl<'a, T> AssumeInit for ViewData<'a, MaybeUninit<T>> {
type Output = ViewData<'a, T>;

unsafe fn assume_init(self) -> Self::Output {
std::mem::transmute(self)
}
}

/// Storage for a mutable tensor view.
///
/// This has the same representation in memory as a mutable slice: a pointer
Expand Down Expand Up @@ -420,6 +431,14 @@ unsafe impl<T> StorageMut for ViewMutData<'_, T> {
}
}

impl<'a, T> AssumeInit for ViewMutData<'a, MaybeUninit<T>> {
type Output = ViewMutData<'a, T>;

unsafe fn assume_init(self) -> Self::Output {
std::mem::transmute(self)
}
}

/// Tensor storage which may be either owned or borrowed.
///
/// The name is taken from [`std::borrow::Cow`] in the standard library,
Expand Down
39 changes: 1 addition & 38 deletions rten-tensor/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::fmt::Debug;
use std::mem::MaybeUninit;
use std::ops::{Index, IndexMut, Range};

use crate::assume_init::AssumeInit;
use crate::copy::{
copy_into, copy_into_slice, copy_into_uninit, copy_range_into_slice, map_into_slice,
};
Expand Down Expand Up @@ -1204,44 +1205,6 @@ impl<T, L: MutLayout> TensorBase<CowData<'_, T>, L> {
}
}

/// Trait for converting potentially uninitialized tensor element storage to
/// initialized storage.
pub trait AssumeInit {
/// The type of the initialized storage.
type Output;

/// Promise that all elements in the storage have been initialized.
///
/// # Safety
///
/// The caller must guarantee that all elements have been initialized.
unsafe fn assume_init(self) -> Self::Output;
}

impl<T> AssumeInit for Vec<MaybeUninit<T>> {
type Output = Vec<T>;

unsafe fn assume_init(self) -> Self::Output {
std::mem::transmute(self)
}
}

impl<'a, T> AssumeInit for ViewData<'a, MaybeUninit<T>> {
type Output = ViewData<'a, T>;

unsafe fn assume_init(self) -> Self::Output {
std::mem::transmute(self)
}
}

impl<'a, T> AssumeInit for ViewMutData<'a, MaybeUninit<T>> {
type Output = ViewMutData<'a, T>;

unsafe fn assume_init(self) -> Self::Output {
std::mem::transmute(self)
}
}

impl<T, S: Storage<Elem = MaybeUninit<T>> + AssumeInit, L: Clone + MutLayout> TensorBase<S, L>
where
<S as AssumeInit>::Output: Storage<Elem = T>,
Expand Down
7 changes: 4 additions & 3 deletions src/gemm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ use std::ops::{Add, Mul, Range};

use rayon::prelude::*;
use rten_tensor::prelude::*;
use rten_tensor::{Alloc, GlobalAlloc, Matrix, MatrixLayout, MatrixMut, NdTensorView, Storage};
use rten_tensor::{
Alloc, AssumeInit, GlobalAlloc, Matrix, MatrixLayout, MatrixMut, NdTensorView, Storage,
};

use crate::iter_util::{range_chunks, MaybeParIter};
use crate::number::Identities;
Expand Down Expand Up @@ -601,8 +603,7 @@ fn gemv<LhsT: GemmInT, RhsT: GemmInT, OutT: GemmOutT>(
}

// Safety: Calls to `gemv_kernel` initialized all output elements.
let out_chunk =
unsafe { std::mem::transmute::<&mut [MaybeUninit<OutT>], &mut [OutT]>(out_chunk) };
let out_chunk = unsafe { out_chunk.assume_init() };
match bias {
Some(BiasVector::Column(bias)) => {
let bias = bias[0];
Expand Down
4 changes: 2 additions & 2 deletions src/ops/concat.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::mem::MaybeUninit;

use rten_tensor::prelude::*;
use rten_tensor::{NdTensorView, Tensor, TensorView};
use rten_tensor::{AssumeInit, NdTensorView, Tensor, TensorView};

use smallvec::SmallVec;

Expand Down Expand Up @@ -155,7 +155,7 @@ where
dest.copy_from_slice(uninit_src);

// SAFETY: Valid elements have just been copied into `this` so it is initialized
unsafe { std::mem::transmute(dest) }
unsafe { dest.assume_init() }
}

/// Recursively tile (ie. repeatly copy) chunks of `input` to `output`.
Expand Down
6 changes: 3 additions & 3 deletions src/ops/conv/depthwise.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::sync::atomic::{AtomicUsize, Ordering};

use rayon::prelude::*;
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, NdTensorView, NdTensorViewMut};
use rten_tensor::{AssumeInit, NdTensor, NdTensorView, NdTensorViewMut};
use smallvec::SmallVec;

use crate::iter_util::{range_chunks, unroll_loop};
Expand Down Expand Up @@ -105,7 +105,7 @@ impl DepthwiseConvKernel<f32, f32, f32> for GenericDepthwiseConvKernel {
for x in out_row.iter_mut() {
x.write(out_init);
}
let out_row: &mut [f32] = unsafe { std::mem::transmute(out_row) };
let out_row: &mut [f32] = unsafe { out_row.assume_init() };

for k_y in 0..params.kernel_h {
let in_y = out_y * params.stride_h + k_y * params.dilation_y;
Expand Down Expand Up @@ -155,7 +155,7 @@ impl DepthwiseConvKernel<i8, u8, i32> for GenericDepthwiseConvKernel {
for x in out_row.iter_mut() {
x.write(out_init);
}
let out_row: &mut [i32] = unsafe { std::mem::transmute(out_row) };
let out_row: &mut [i32] = unsafe { out_row.assume_init() };

for k_y in 0..params.kernel_h {
let in_y = out_y * params.stride_h + k_y * params.dilation_y;
Expand Down
4 changes: 2 additions & 2 deletions src/ops/quantize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::mem::MaybeUninit;

use rten_simd::dispatch::SimdOp;
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, NdTensorView, Scalar, Tensor, TensorView};
use rten_tensor::{AssumeInit, NdTensor, NdTensorView, Scalar, Tensor, TensorView};
use rten_vecmath as vecmath;

use crate::ops::{
Expand Down Expand Up @@ -146,7 +146,7 @@ pub trait Quantize<To> {
for (x, y) in src.iter().zip(dest.iter_mut()) {
y.write(x.quantize(inv_scale, zero_point));
}
unsafe { std::mem::transmute::<&mut [MaybeUninit<To>], &mut [To]>(dest) }
unsafe { dest.assume_init() }
}
}

Expand Down

0 comments on commit 024c1e9

Please sign in to comment.