Skip to content

Commit

Permalink
Merge pull request #599 from robertknight/simd-src-dest
Browse files Browse the repository at this point in the history
Replace PtrLen/MutPtrLen with SrcDest for vectorized op inputs/outputs
  • Loading branch information
robertknight authored Feb 16, 2025
2 parents 024c1e9 + b3c18c4 commit 8674c35
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 184 deletions.
33 changes: 13 additions & 20 deletions rten-simd/src/dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
use std::mem::MaybeUninit;

use crate::functional::simd_map;
use crate::span::{MutPtrLen, PtrLen};
use crate::span::SrcDest;
use crate::SimdFloat;

/// Dispatches SIMD operations using the preferred SIMD types for the current
Expand Down Expand Up @@ -144,8 +144,8 @@ pub trait SimdUnaryOp {
where
Self: Sized,
{
let wrapped_op = SimdMapOp::wrap(input.into(), output.into(), self);
dispatch(wrapped_op)
let wrapped_op = SimdMapOp::wrap((input, output).into(), self);
dispatch(wrapped_op);
}

/// Apply a vectorized unary function to a mutable slice.
Expand All @@ -156,40 +156,33 @@ pub trait SimdUnaryOp {
where
Self: Sized,
{
let out: MutPtrLen<f32> = input.into();
let wrapped_op = SimdMapOp::wrap(input.into(), out.as_uninit(), self);
dispatch(wrapped_op)
let wrapped_op = SimdMapOp::wrap(input.into(), self);
dispatch(wrapped_op);
}
}

/// SIMD operation which applies a unary operator `Op` to all elements in
/// an input buffer using [`simd_map`].
pub struct SimdMapOp<'a, Op: SimdUnaryOp> {
input: PtrLen<f32>,
output: MutPtrLen<MaybeUninit<f32>>,
src_dest: SrcDest<'a, f32>,
op: &'a Op,
}

impl<'a, Op: SimdUnaryOp> SimdMapOp<'a, Op> {
pub fn wrap(
input: PtrLen<f32>,
output: MutPtrLen<MaybeUninit<f32>>,
op: &'a Op,
) -> SimdMapOp<'a, Op> {
SimdMapOp { input, output, op }
pub fn wrap(src_dest: SrcDest<'a, f32>, op: &'a Op) -> SimdMapOp<'a, Op> {
SimdMapOp { src_dest, op }
}
}

impl<Op: SimdUnaryOp> SimdOp for SimdMapOp<'_, Op> {
type Output = ();
impl<'a, Op: SimdUnaryOp> SimdOp for SimdMapOp<'a, Op> {
type Output = &'a mut [f32];

#[inline(always)]
unsafe fn eval<S: SimdFloat>(self) {
unsafe fn eval<S: SimdFloat>(self) -> Self::Output {
simd_map(
self.input,
self.output,
self.src_dest,
#[inline(always)]
|x: S| self.op.eval(x),
);
)
}
}
25 changes: 10 additions & 15 deletions rten-simd/src/functional.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
//! Higher order functions (map, fold etc.) that use vectorized operations.
use std::mem::MaybeUninit;

use crate::span::{MutPtrLen, PtrLen};
use crate::span::SrcDest;
use crate::{Simd, SimdMask};

/// Apply a unary operation to each element in `input` and store the results
Expand All @@ -21,15 +19,10 @@ use crate::{Simd, SimdMask};
/// current system.
#[inline(always)]
pub unsafe fn simd_map<S: Simd, Op: FnMut(S) -> S>(
input: PtrLen<S::Elem>,
output: MutPtrLen<MaybeUninit<S::Elem>>,
mut src_dest: SrcDest<S::Elem>,
mut op: Op,
) {
assert!(input.len() == output.len());

let mut n = input.len();
let mut in_ptr = input.ptr();
let mut out_ptr = output.ptr();
) -> &mut [S::Elem] {
let (mut in_ptr, mut out_ptr, mut n) = src_dest.src_dest_ptr();

while n >= S::LEN {
let x = S::load(in_ptr);
Expand All @@ -46,6 +39,8 @@ pub unsafe fn simd_map<S: Simd, Op: FnMut(S) -> S>(
let y = op(x);
y.store_partial(out_ptr as *mut S::Elem, n);
}

src_dest.dest_assume_init()
}

/// Apply a vectorized fold operation over `xs`. If the length of `xs` is not
Expand All @@ -58,12 +53,12 @@ pub unsafe fn simd_map<S: Simd, Op: FnMut(S) -> S>(
/// current system.
#[inline(always)]
pub unsafe fn simd_fold<S: Simd, Op: Fn(S, S) -> S>(
xs: PtrLen<S::Elem>,
xs: &[S::Elem],
mut accum: S,
simd_op: Op,
) -> S {
let mut n = xs.len();
let mut x_ptr = xs.ptr();
let mut x_ptr = xs.as_ptr();

while n >= S::LEN {
let x = S::load(x_ptr);
Expand Down Expand Up @@ -92,12 +87,12 @@ pub unsafe fn simd_fold<S: Simd, Op: Fn(S, S) -> S>(
/// current system.
#[inline(always)]
pub unsafe fn simd_fold_array<S: Simd, const N: usize, Op: Fn([S; N], S) -> [S; N]>(
xs: PtrLen<S::Elem>,
xs: &[S::Elem],
mut accum: [S; N],
simd_op: Op,
) -> [S; N] {
let mut n = xs.len();
let mut x_ptr = xs.ptr();
let mut x_ptr = xs.as_ptr();

while n >= S::LEN {
let x = S::load(x_ptr);
Expand Down
155 changes: 59 additions & 96 deletions rten-simd/src/span.rs
Original file line number Diff line number Diff line change
@@ -1,127 +1,90 @@
//! Slice-like types without the restrictions on aliasing.
//! Slice-like types used as inputs and outputs for vectorized operations.
use std::mem::{transmute, MaybeUninit};

/// Const pointer to a range of `T`s.
enum SrcDestInner<'a, T> {
/// Separate buffers: a read-only source slice and a destination slice
/// whose elements may be uninitialized. The `From` conversion that builds
/// this variant asserts that both slices have the same length.
InOut(&'a [T], &'a mut [MaybeUninit<T>]),
/// A single mutable slice used as both the source and the destination,
/// for operations that execute in-place.
InMut(&'a mut [T]),
}

/// Input-output buffer for vectorized operations.
///
/// This is like an `&[T]`, but without the guarantee that no mutable aliases
/// exist. This is useful as it enables re-using the same unsafe code for
/// mutating and non-mutating variants of a function.
#[derive(Copy, Clone)]
pub struct PtrLen<T> {
ptr: *const T,
len: usize,
/// This can either be a single mutable buffer for operations that execute
/// in-place (`&mut [T]`) or a pair of input and output buffers where the
/// output is uninitialized (`(&[T], &mut [MaybeUninit<T>])`) and both buffers
/// must have the same length.
pub struct SrcDest<'a, T: Copy> {
inner: SrcDestInner<'a, T>,
}

impl<T> PtrLen<T> {
pub fn ptr(&self) -> *const T {
self.ptr
impl<'a, T: Copy> SrcDest<'a, T> {
/// Return the source slice.
pub fn src(&self) -> &[T] {
match &self.inner {
SrcDestInner::InOut(src, _dest) => src,
SrcDestInner::InMut(src_mut) => src_mut,
}
}

/// Return the length of the input and output slices.
pub fn len(&self) -> usize {
self.len
self.src().len()
}

/// Return true if the input and output slices are empty.
pub fn is_empty(&self) -> bool {
self.len == 0
}
}

impl<'a, T> From<&'a [T]> for PtrLen<T> {
fn from(val: &'a [T]) -> PtrLen<T> {
PtrLen {
ptr: val.as_ptr(),
len: val.len(),
}
}
}

impl<'a, T> From<&'a mut [T]> for PtrLen<T> {
fn from(val: &'a mut [T]) -> PtrLen<T> {
PtrLen {
ptr: val.as_ptr(),
len: val.len(),
}
self.src().is_empty()
}
}

impl<T> From<MutPtrLen<T>> for PtrLen<T> {
fn from(val: MutPtrLen<T>) -> PtrLen<T> {
PtrLen {
ptr: val.ptr,
len: val.len,
/// Return source and destination slice pointers and the length.
///
/// The source and destination will either alias, or the destination will
/// be a non-aliasing, uninitialized slice.
pub fn src_dest_ptr(&mut self) -> (*const T, *mut MaybeUninit<T>, usize) {
match &mut self.inner {
SrcDestInner::InOut(src, dest) => (src.as_ptr(), dest.as_mut_ptr(), src.len()),
SrcDestInner::InMut(src) => (
src.as_ptr(),
src.as_mut_ptr() as *mut MaybeUninit<T>,
src.len(),
),
}
}
}

/// Mutable pointer to a range of `T`s.
///
/// This is like an `&mut [T]`, but without the guarantee that no aliases exist.
#[derive(Copy, Clone)]
pub struct MutPtrLen<T> {
ptr: *mut T,
len: usize,
}

impl<T> MutPtrLen<T> {
pub fn ptr(&self) -> *mut T {
self.ptr
}

pub fn len(&self) -> usize {
self.len
}

pub fn is_empty(&self) -> bool {
self.len == 0
}
}

impl<T> MutPtrLen<MaybeUninit<T>> {
/// Promise that the span of `T`s that are pointed to have been initialized.
/// Return the initialized destination slice.
///
/// # Safety
///
/// The caller must ensure that all elements referenced by this range have
/// been initialized.
pub unsafe fn assume_init(self) -> MutPtrLen<T> {
MutPtrLen {
ptr: unsafe { transmute::<*mut MaybeUninit<T>, *mut T>(self.ptr) },
len: self.len,
/// If this instance was constructed with an uninitialized destination
/// buffer, all elements must have been initialized before this is called.
pub unsafe fn dest_assume_init(self) -> &'a mut [T] {
match self.inner {
SrcDestInner::InOut(_src, dest) => transmute::<&mut [MaybeUninit<T>], &mut [T]>(dest),
SrcDestInner::InMut(src) => src,
}
}
}

impl<T> MutPtrLen<T> {
/// Transmute a span of initialized `T`s to uninitialized `T`s.
pub fn as_uninit(self) -> MutPtrLen<MaybeUninit<T>>
where
T: Copy,
{
MutPtrLen {
ptr: unsafe { transmute::<*mut T, *mut MaybeUninit<T>>(self.ptr) },
len: self.len,
impl<'a, T: Copy> From<(&'a [T], &'a mut [MaybeUninit<T>])> for SrcDest<'a, T> {
fn from(val: (&'a [T], &'a mut [MaybeUninit<T>])) -> Self {
let (src, dest) = val;
assert_eq!(
src.len(),
dest.len(),
"src len {} != dest len {}",
src.len(),
dest.len(),
);
SrcDest {
inner: SrcDestInner::InOut(src, dest),
}
}

/// Convert `self` into a slice.
///
/// # Safety
///
/// The caller must uphold all the invariants specified in
/// [`std::slice::from_raw_parts_mut`]. In particular all elements must be
/// initialized, and there must be no other mutable references to any
/// elements in the slice.
pub unsafe fn as_slice<'a>(self) -> &'a mut [T] {
std::slice::from_raw_parts_mut(self.ptr, self.len)
}
}

impl<'a, T> From<&'a mut [T]> for MutPtrLen<T> {
fn from(val: &'a mut [T]) -> MutPtrLen<T> {
MutPtrLen {
ptr: val.as_mut_ptr(),
len: val.len(),
impl<'a, T: Copy> From<&'a mut [T]> for SrcDest<'a, T> {
fn from(val: &'a mut [T]) -> Self {
SrcDest {
inner: SrcDestInner::InMut(val),
}
}
}
2 changes: 1 addition & 1 deletion rten-vecmath/src/min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ impl SimdOp for MinMax<'_> {
#[inline(always)]
unsafe fn eval<S: SimdFloat>(self) -> Self::Output {
let [vec_min, vec_max] = simd_fold_array(
self.input.into(),
self.input,
[S::splat(f32::MAX), S::splat(f32::MIN)],
#[inline(always)]
|[min, max], x| [x.min(min), x.max(max)],
Expand Down
Loading

0 comments on commit 8674c35

Please sign in to comment.