Skip to content

Commit

Permalink
Merge pull request #166 from robertknight/faster-non-contiguous-copy
Browse files Browse the repository at this point in the history
Optimize `TensorBase::copy_from` for non-contiguous `self`
  • Loading branch information
robertknight authored May 10, 2024
2 parents 6050f4c + d9d6c02 commit 41d8653
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 20 deletions.
83 changes: 71 additions & 12 deletions rten-tensor/src/transpose.rs → rten-tensor/src/copy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use std::mem::MaybeUninit;
use std::ops::Range;

use crate::{AsView, Layout};
use crate::{Matrix, MatrixLayout, MatrixMut, NdTensorView, NdTensorViewMut, TensorView};
use crate::{
Matrix, MatrixLayout, MatrixMut, NdTensorView, NdTensorViewMut, TensorView, TensorViewMut,
};

/// Iterator returned by [range_chunks].
pub struct RangeChunks {
Expand Down Expand Up @@ -91,10 +93,11 @@ fn copy_blocked<T: Clone>(src: Matrix<T>, mut dest: MatrixMut<MaybeUninit<T>>) {
}
}

/// Copy elements of `src` into `dest` in contiguous order.
/// Copy elements of `src` into a contiguous destination slice with the same
/// length.
///
/// Returns `dest` as an initialized slice.
pub fn copy_contiguous<'a, T: Clone>(
pub fn copy_into_slice<'a, T: Clone>(
src: TensorView<T>,
dest: &'a mut [MaybeUninit<T>],
) -> &'a [T] {
Expand Down Expand Up @@ -160,9 +163,50 @@ pub fn copy_contiguous<'a, T: Clone>(
}
}

/// Clone elements of `src` into `dest`.
///
/// This is functionally equivalent to:
///
/// ```text
/// src.iter().zip(dest.iter_mut()).for_each(|(y, x)| *y = x.clone())
/// ```
///
/// But more efficient, especially when `src` or `dest` are not contiguous.
pub fn copy_into<T: Clone>(mut src: TensorView<T>, mut dest: TensorViewMut<T>) {
assert!(src.shape() == dest.shape());

while src.ndim() < 4 {
src.insert_axis(0);
dest.insert_axis(0);
}

// Efficiency could be improved here by sorting dims so that those with
// the smallest stride are innermost. Also it could use the blocked copy
// that `copy_into_slice` uses to avoid cache conflicts when inputs are
// transposed.

src.inner_iter::<4>()
.zip(dest.inner_iter_mut::<4>())
.for_each(|(src, mut dest)| {
for i0 in 0..src.size(0) {
for i1 in 0..src.size(1) {
for i2 in 0..src.size(2) {
for i3 in 0..src.size(3) {
unsafe {
*dest.get_unchecked_mut([i0, i1, i2, i3]) =
src.get_unchecked([i0, i1, i2, i3]).clone();
}
}
}
}
}
});
}

#[cfg(test)]
mod tests {
use super::copy_contiguous;
use super::{copy_into, copy_into_slice};
use crate::rng::XorShiftRng;
use crate::{AsView, Layout, Tensor, TensorView};

/// Return the elements of `src` as a contiguous vector, in the same order they
Expand All @@ -173,10 +217,10 @@ mod tests {
///
/// This is equivalent to `src.iter().cloned().collect::<Vec<_>>()` but
/// faster.
fn contiguous_data<T: Clone>(src: TensorView<T>) -> Vec<T> {
fn copy_into_vec<T: Clone>(src: TensorView<T>) -> Vec<T> {
let src_len = src.len();
let mut result = Vec::with_capacity(src_len);
copy_contiguous(src, &mut result.spare_capacity_mut()[..src_len]);
copy_into_slice(src, &mut result.spare_capacity_mut()[..src_len]);

// Safety: `copy_contiguous` initialized `src_len` elements of result.
unsafe { result.set_len(src_len) };
Expand All @@ -185,16 +229,31 @@ mod tests {
}

#[test]
fn test_contiguous_data() {
fn test_copy_into() {
let mut rng = XorShiftRng::new(1234);
for ndim in 0..5 {
let shape: Vec<_> = (0..ndim).map(|d| d + 1).collect();
let src = Tensor::rand(&shape, &mut rng);
let src = src.transposed();

let mut dest = Tensor::zeros(src.shape());
copy_into(src.view(), dest.view_mut());

assert_eq!(dest, src);
}
}

#[test]
fn test_copy_into_slice() {
// <= 4 dims
let x = Tensor::from_data(&[2, 2], vec![1, 2, 3, 4]);
assert_eq!(contiguous_data(x.view()), [1, 2, 3, 4]);
assert_eq!(contiguous_data(x.transposed()), [1, 3, 2, 4]);
assert_eq!(copy_into_vec(x.view()), [1, 2, 3, 4]);
assert_eq!(copy_into_vec(x.transposed()), [1, 3, 2, 4]);

// > 4 dims
let x = Tensor::from_data(&[1, 1, 1, 2, 2], vec![1, 2, 3, 4]);
assert_eq!(contiguous_data(x.view()), [1, 2, 3, 4]);
assert_eq!(contiguous_data(x.transposed()), [1, 3, 2, 4]);
assert_eq!(copy_into_vec(x.view()), [1, 2, 3, 4]);
assert_eq!(copy_into_vec(x.transposed()), [1, 3, 2, 4]);

// Transposed matrices of varying sizes. This includes:
//
Expand All @@ -205,7 +264,7 @@ mod tests {
for size in [0usize, 2, 4, 8, 15, 16, 32, 64, 65, 68] {
let x = Tensor::<i32>::arange(0, (size * size) as i32, None);
let x = x.reshaped([size, size]);
let transposed = contiguous_data(x.transposed().as_dyn());
let transposed = copy_into_vec(x.transposed().as_dyn());
let expected = x.transposed().iter().copied().collect::<Vec<_>>();
assert_eq!(transposed, expected);
}
Expand Down
2 changes: 1 addition & 1 deletion rten-tensor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
//! }
//! ```
mod copy;
mod errors;
mod index_iterator;
mod iterators;
Expand All @@ -49,7 +50,6 @@ mod overlap;
mod slice_range;
mod storage;
mod tensor;
mod transpose;

/// Trait for sources of random data for tensors, for use with [Tensor::rand].
pub trait RandomSource<T> {
Expand Down
12 changes: 5 additions & 7 deletions rten-tensor/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::borrow::Cow;
use std::mem::MaybeUninit;
use std::ops::{Index, IndexMut, Range};

use crate::copy::{copy_into, copy_into_slice};
use crate::errors::{DimensionError, FromDataError, SliceError};
use crate::iterators::{
AxisChunks, AxisChunksMut, AxisIter, AxisIterMut, BroadcastIter, InnerIter, InnerIterDyn,
Expand All @@ -12,7 +13,6 @@ use crate::layout::{
OverlapPolicy, ResizeLayout,
};
use crate::storage::{CowData, IntoStorage, Storage, StorageMut, ViewData, ViewMutData};
use crate::transpose::copy_contiguous;
use crate::{Alloc, GlobalAlloc, IntoSliceItems, RandomSource, SliceItem};

/// The base type for multi-dimensional arrays. This consists of storage for
Expand Down Expand Up @@ -479,12 +479,10 @@ impl<S: StorageMut, L: MutLayout> TensorBase<S, L> {
}

// Copy source into destination in contiguous order.
copy_contiguous(other.as_dyn(), uninit_dest);
copy_into_slice(other.as_dyn(), uninit_dest);
}
} else {
for (out, x) in self.iter_mut().zip(other.iter()) {
*out = x.clone();
}
copy_into(other.as_dyn(), self.as_dyn_mut());
}
}

Expand Down Expand Up @@ -954,7 +952,7 @@ where
let data: &[MaybeUninit<T>] = unsafe { std::mem::transmute(data) };
self.data.as_mut().clone_from_slice(data);
} else {
copy_contiguous(other.as_dyn(), self.data.as_mut());
copy_into_slice(other.as_dyn(), self.data.as_mut());
}
unsafe { self.assume_init() }
}
Expand Down Expand Up @@ -1398,7 +1396,7 @@ impl<T, S: Storage<Elem = T>, L: MutLayout + Clone> AsView for TensorBase<S, L>
if let Some(data) = self.data() {
buf.extend_from_slice(data);
} else {
copy_contiguous(self.as_dyn(), &mut buf.spare_capacity_mut()[..len]);
copy_into_slice(self.as_dyn(), &mut buf.spare_capacity_mut()[..len]);

// Safety: We initialized `len` elements.
unsafe { buf.set_len(len) }
Expand Down

0 comments on commit 41d8653

Please sign in to comment.