Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize TensorBase::copy_from for non-contiguous self #166

Merged
merged 1 commit into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 71 additions & 12 deletions rten-tensor/src/transpose.rs → rten-tensor/src/copy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use std::mem::MaybeUninit;
use std::ops::Range;

use crate::{AsView, Layout};
use crate::{Matrix, MatrixLayout, MatrixMut, NdTensorView, NdTensorViewMut, TensorView};
use crate::{
Matrix, MatrixLayout, MatrixMut, NdTensorView, NdTensorViewMut, TensorView, TensorViewMut,
};

/// Iterator returned by [range_chunks].
pub struct RangeChunks {
Expand Down Expand Up @@ -91,10 +93,11 @@ fn copy_blocked<T: Clone>(src: Matrix<T>, mut dest: MatrixMut<MaybeUninit<T>>) {
}
}

/// Copy elements of `src` into `dest` in contiguous order.
/// Copy elements of `src` into a contiguous destination slice with the same
/// length.
///
/// Returns `dest` as an initialized slice.
pub fn copy_contiguous<'a, T: Clone>(
pub fn copy_into_slice<'a, T: Clone>(
src: TensorView<T>,
dest: &'a mut [MaybeUninit<T>],
) -> &'a [T] {
Expand Down Expand Up @@ -160,9 +163,50 @@ pub fn copy_contiguous<'a, T: Clone>(
}
}

/// Clone elements of `src` into `dest`.
///
/// This is functionally equivalent to:
///
/// ```text
/// src.iter().zip(dest.iter_mut()).for_each(|(x, y)| *y = x.clone())
/// ```
///
/// But more efficient, especially when `src` or `dest` are not contiguous.
///
/// # Panics
///
/// Panics if `src` and `dest` do not have the same shape.
pub fn copy_into<T: Clone>(mut src: TensorView<T>, mut dest: TensorViewMut<T>) {
    assert!(src.shape() == dest.shape());

    // Pad to at least 4 dims by inserting leading 1-sized axes, so we can
    // iterate over fixed-rank inner views below. This does not change the
    // element order.
    while src.ndim() < 4 {
        src.insert_axis(0);
        dest.insert_axis(0);
    }

    // Efficiency could be improved here by sorting dims so that those with
    // the smallest stride are innermost. Also it could use the blocked copy
    // that `copy_into_slice` uses to avoid cache conflicts when inputs are
    // transposed.

    src.inner_iter::<4>()
        .zip(dest.inner_iter_mut::<4>())
        .for_each(|(src, mut dest)| {
            for i0 in 0..src.size(0) {
                for i1 in 0..src.size(1) {
                    for i2 in 0..src.size(2) {
                        for i3 in 0..src.size(3) {
                            // Safety: `src` and `dest` have the same shape
                            // (asserted above) and each index is within
                            // `0..size(dim)` for its dimension.
                            unsafe {
                                *dest.get_unchecked_mut([i0, i1, i2, i3]) =
                                    src.get_unchecked([i0, i1, i2, i3]).clone();
                            }
                        }
                    }
                }
            }
        });
}

#[cfg(test)]
mod tests {
use super::copy_contiguous;
use super::{copy_into, copy_into_slice};
use crate::rng::XorShiftRng;
use crate::{AsView, Layout, Tensor, TensorView};

/// Return the elements of `src` as a contiguous vector, in the same order they
Expand All @@ -173,10 +217,10 @@ mod tests {
///
/// This is equivalent to `src.iter().cloned().collect::<Vec<_>>()` but
/// faster.
fn contiguous_data<T: Clone>(src: TensorView<T>) -> Vec<T> {
fn copy_into_vec<T: Clone>(src: TensorView<T>) -> Vec<T> {
let src_len = src.len();
let mut result = Vec::with_capacity(src_len);
copy_contiguous(src, &mut result.spare_capacity_mut()[..src_len]);
copy_into_slice(src, &mut result.spare_capacity_mut()[..src_len]);

        // Safety: `copy_into_slice` initialized `src_len` elements of result.
unsafe { result.set_len(src_len) };
Expand All @@ -185,16 +229,31 @@ mod tests {
}

#[test]
fn test_contiguous_data() {
fn test_copy_into() {
let mut rng = XorShiftRng::new(1234);
for ndim in 0..5 {
let shape: Vec<_> = (0..ndim).map(|d| d + 1).collect();
let src = Tensor::rand(&shape, &mut rng);
let src = src.transposed();

let mut dest = Tensor::zeros(src.shape());
copy_into(src.view(), dest.view_mut());

assert_eq!(dest, src);
}
}

#[test]
fn test_copy_into_slice() {
// <= 4 dims
let x = Tensor::from_data(&[2, 2], vec![1, 2, 3, 4]);
assert_eq!(contiguous_data(x.view()), [1, 2, 3, 4]);
assert_eq!(contiguous_data(x.transposed()), [1, 3, 2, 4]);
assert_eq!(copy_into_vec(x.view()), [1, 2, 3, 4]);
assert_eq!(copy_into_vec(x.transposed()), [1, 3, 2, 4]);

// > 4 dims
let x = Tensor::from_data(&[1, 1, 1, 2, 2], vec![1, 2, 3, 4]);
assert_eq!(contiguous_data(x.view()), [1, 2, 3, 4]);
assert_eq!(contiguous_data(x.transposed()), [1, 3, 2, 4]);
assert_eq!(copy_into_vec(x.view()), [1, 2, 3, 4]);
assert_eq!(copy_into_vec(x.transposed()), [1, 3, 2, 4]);

// Transposed matrices of varying sizes. This includes:
//
Expand All @@ -205,7 +264,7 @@ mod tests {
for size in [0usize, 2, 4, 8, 15, 16, 32, 64, 65, 68] {
let x = Tensor::<i32>::arange(0, (size * size) as i32, None);
let x = x.reshaped([size, size]);
let transposed = contiguous_data(x.transposed().as_dyn());
let transposed = copy_into_vec(x.transposed().as_dyn());
let expected = x.transposed().iter().copied().collect::<Vec<_>>();
assert_eq!(transposed, expected);
}
Expand Down
2 changes: 1 addition & 1 deletion rten-tensor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
//! }
//! ```

mod copy;
mod errors;
mod index_iterator;
mod iterators;
Expand All @@ -49,7 +50,6 @@ mod overlap;
mod slice_range;
mod storage;
mod tensor;
mod transpose;

/// Trait for sources of random data for tensors, for use with [Tensor::rand].
pub trait RandomSource<T> {
Expand Down
12 changes: 5 additions & 7 deletions rten-tensor/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::borrow::Cow;
use std::mem::MaybeUninit;
use std::ops::{Index, IndexMut, Range};

use crate::copy::{copy_into, copy_into_slice};
use crate::errors::{DimensionError, FromDataError, SliceError};
use crate::iterators::{
AxisChunks, AxisChunksMut, AxisIter, AxisIterMut, BroadcastIter, InnerIter, InnerIterDyn,
Expand All @@ -12,7 +13,6 @@ use crate::layout::{
OverlapPolicy, ResizeLayout,
};
use crate::storage::{CowData, IntoStorage, Storage, StorageMut, ViewData, ViewMutData};
use crate::transpose::copy_contiguous;
use crate::{Alloc, GlobalAlloc, IntoSliceItems, RandomSource, SliceItem};

/// The base type for multi-dimensional arrays. This consists of storage for
Expand Down Expand Up @@ -479,12 +479,10 @@ impl<S: StorageMut, L: MutLayout> TensorBase<S, L> {
}

// Copy source into destination in contiguous order.
copy_contiguous(other.as_dyn(), uninit_dest);
copy_into_slice(other.as_dyn(), uninit_dest);
}
} else {
for (out, x) in self.iter_mut().zip(other.iter()) {
*out = x.clone();
}
copy_into(other.as_dyn(), self.as_dyn_mut());
}
}

Expand Down Expand Up @@ -954,7 +952,7 @@ where
let data: &[MaybeUninit<T>] = unsafe { std::mem::transmute(data) };
self.data.as_mut().clone_from_slice(data);
} else {
copy_contiguous(other.as_dyn(), self.data.as_mut());
copy_into_slice(other.as_dyn(), self.data.as_mut());
}
unsafe { self.assume_init() }
}
Expand Down Expand Up @@ -1398,7 +1396,7 @@ impl<T, S: Storage<Elem = T>, L: MutLayout + Clone> AsView for TensorBase<S, L>
if let Some(data) = self.data() {
buf.extend_from_slice(data);
} else {
copy_contiguous(self.as_dyn(), &mut buf.spare_capacity_mut()[..len]);
copy_into_slice(self.as_dyn(), &mut buf.spare_capacity_mut()[..len]);

// Safety: We initialized `len` elements.
unsafe { buf.set_len(len) }
Expand Down
Loading