diff --git a/src/ops/layout.rs b/src/ops/layout.rs
index 143fe76c..a1f1ab07 100644
--- a/src/ops/layout.rs
+++ b/src/ops/layout.rs
@@ -912,6 +912,61 @@ mod tests {
         );
     }
 
+    use rten_tensor::{NdTensorView, TensorView, TensorViewMut};
+
+    /// Naive transpose used as the baseline in `bench_transpose`.
+    ///
+    /// Walks `src` in logical (row-major index) order and writes a clone of
+    /// each element to consecutive positions of `dest`'s underlying data, so
+    /// passing a permuted view as `src` materializes the transpose in `dest`.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `dest.data_mut()` returns `None` (presumably when `dest` is
+    /// not contiguous; confirm against rten_tensor) or if `dest`'s data does
+    /// not hold exactly as many elements as `src`.
+    fn reference_transpose_into<T: Clone>(src: TensorView<T>, mut dest: TensorViewMut<T>) {
+        // Merge axes to maximize iteration count of the innermost loops.
+        let mut src = src;
+        src.merge_axes();
+
+        // Pad with leading axes so the fixed 4-deep loop nest below applies.
+        while src.ndim() < 4 {
+            src.insert_axis(0);
+        }
+
+        let dest_data = dest.data_mut().unwrap();
+
+        let src: NdTensorView<T, 4> = src.nd_view();
+
+        // The unchecked writes below are only sound if `dest` has room for
+        // every element of `src`, so verify that once up front.
+        let src_len: usize = (0..4).map(|dim| src.size(dim)).product();
+        assert_eq!(
+            dest_data.len(),
+            src_len,
+            "dest must have the same number of elements as src"
+        );
+
+        let mut dest_offset = 0;
+        for i0 in 0..src.size(0) {
+            for i1 in 0..src.size(1) {
+                for i2 in 0..src.size(2) {
+                    for i3 in 0..src.size(3) {
+                        // SAFETY: each index is below the size of its axis,
+                        // and `dest_offset` counts visited elements, so it
+                        // stays below `src_len == dest_data.len()`.
+                        unsafe {
+                            let elt = src.get_unchecked([i0, i1, i2, i3]).clone();
+                            *dest_data.get_unchecked_mut(dest_offset) = elt;
+                        }
+                        dest_offset += 1;
+                    }
+                }
+            }
+        }
+    }
+
     #[test]
     #[ignore]
     fn bench_transpose() {
@@ -937,6 +992,10 @@ mod tests {
             // experience slowdown due to poor cache usage. There can also be
             // issues to a lesser extent with sizes which are a multiple of
             // (cache_line_size / element_size).
+            Case {
+                shape: &[128, 128],
+                perm: &[1, 0],
+            },
             Case {
                 shape: &[256, 256],
                 perm: &[1, 0],
@@ -950,6 +1009,14 @@ mod tests {
                 perm: &[1, 0],
             },
             // Matrix transpose with non power-of-2 sizes.
+            Case {
+                shape: &[127, 127],
+                perm: &[1, 0],
+            },
+            Case {
+                shape: &[255, 255],
+                perm: &[1, 0],
+            },
             Case {
                 shape: &[513, 513],
                 perm: &[1, 0],
@@ -983,24 +1050,33 @@ mod tests {
 
         for Case { shape, perm } in cases {
             let tensor = Tensor::rand(shape, &mut rng);
+            let mut dest = Tensor::zeros(shape);
 
             // Do a simple copy. This provides a lower-bound on how fast
             // transpose can operate.
-            let copy_stats = run_bench(100, format!("copy {:?}", shape), || {
-                tensor.view().to_tensor();
+            let copy_stats = run_bench(100, None, || {
+                dest.copy_from(&tensor.view());
             });
+            assert_eq!(dest, tensor);
 
-            let transpose_stats = run_bench(
-                100,
-                format!("transpose {:?} perm {:?}", shape, perm),
-                || {
-                    transpose(tensor.view(), Some(perm)).unwrap();
-                },
-            );
+            let reference_transpose_stats = run_bench(100, None, || {
+                let transposed = tensor.permuted(perm);
+                reference_transpose_into(transposed.view(), dest.reshaped_mut(transposed.shape()));
+            });
+
+            let transpose_stats = run_bench(100, None, || {
+                let transposed = tensor.permuted(perm);
+                dest.reshape(transposed.shape());
+                dest.copy_from(&transposed);
+            });
+            assert_eq!(dest, tensor.permuted(perm));
 
             let transpose_overhead =
                 (transpose_stats.mean - copy_stats.mean).max(0.) / copy_stats.mean;
-            println!("transpose {:?} overhead {}", shape, transpose_overhead);
+            println!(
+                "transpose shape {:?} perm {:?} copy {:.3}ms ref transpose {:.3}ms opt transpose {:.3}ms overhead {}",
+                shape, perm, copy_stats.median, reference_transpose_stats.median, transpose_stats.median, transpose_overhead
+            );
         }
     }
 }