diff --git a/rten-tensor/src/iterators.rs b/rten-tensor/src/iterators.rs
index 6b9d653a..b74bd48b 100644
--- a/rten-tensor/src/iterators.rs
+++ b/rten-tensor/src/iterators.rs
@@ -2,12 +2,56 @@ use std::iter::{repeat, zip, Cycle, FusedIterator, StepBy, Take};
 use std::ops::{Add, Range};
 use std::slice;

-use super::layout::DynLayout;
 use super::range::{SliceItem, SliceRange};
 use crate::{
     to_slice_items, DynIndices, Layout, NdTensorView, NdTensorViewMut, TensorView, TensorViewMut,
 };

+/// Borrowed reference to a tensor's data and layout. This differs from
+/// [TensorView] in that it borrows the layout rather than having its own.
+///
+/// `'d` is the lifetime of the data and `'l` the lifetime of the layout.
+pub(crate) struct ViewRef<'d, 'l, T, L: Layout> {
+    data: &'d [T],
+    layout: &'l L,
+}
+
+impl<'d, 'l, T, L: Layout> ViewRef<'d, 'l, T, L> {
+    pub(crate) fn new(data: &'d [T], layout: &'l L) -> ViewRef<'d, 'l, T, L> {
+        ViewRef { data, layout }
+    }
+
+    fn contiguous_data(&self) -> Option<&'d [T]> {
+        self.layout.is_contiguous().then_some(self.data)
+    }
+
+    fn shape(&self) -> L::Index<'_> {
+        self.layout.shape()
+    }
+}
+
+impl<'d, 'l, T, L: Layout> Clone for ViewRef<'d, 'l, T, L> {
+    fn clone(&self) -> ViewRef<'d, 'l, T, L> {
+        ViewRef {
+            data: self.data,
+            layout: self.layout,
+        }
+    }
+}
+
+/// Mutably borrowed reference to a tensor's data and layout. This differs from
+/// [TensorViewMut] in that it borrows the layout rather than having its own.
+pub(crate) struct MutViewRef<'d, 'l, T, L: Layout> {
+    data: &'d mut [T],
+    layout: &'l L,
+}
+
+impl<'d, 'l, T, L: Layout> MutViewRef<'d, 'l, T, L> {
+    pub(crate) fn new(data: &'d mut [T], layout: &'l L) -> MutViewRef<'d, 'l, T, L> {
+        MutViewRef { data, layout }
+    }
+}
+
 /// IterPos tracks the position within a single dimension of an IndexingIter.
 #[derive(Copy, Clone, Debug)]
 struct IterPos {
@@ -62,9 +106,10 @@ struct IndexingIterBase {

 impl IndexingIterBase {
     /// Create an iterator over element offsets in `tensor`.
-    fn new(layout: &DynLayout) -> IndexingIterBase {
+    fn new<L: Layout>(layout: &L) -> IndexingIterBase {
         let dims = layout
             .shape()
+            .as_ref()
             .iter()
             .enumerate()
             .map(|(dim, &len)| IterPos::new(len, layout.stride(dim) as isize))
@@ -79,11 +124,13 @@ impl IndexingIterBase {

     /// Create an iterator over offsets of elements in `tensor`, as if it had
     /// a given `shape`. This will repeat offsets as necessary.
-    fn broadcast(layout: &DynLayout, shape: &[usize]) -> IndexingIterBase {
+    fn broadcast<L: Layout>(layout: &L, shape: &[usize]) -> IndexingIterBase {
         // nb. We require that the broadcast shape has a length >= the actual
         // shape.
-        let added_dims = shape.len() - layout.shape().len();
-        let padded_tensor_shape = repeat(&0).take(added_dims).chain(layout.shape().iter());
+        let added_dims = shape.len() - layout.ndim();
+        let layout_shape = layout.shape();
+        let layout_shape = layout_shape.as_ref();
+        let padded_tensor_shape = repeat(&0).take(added_dims).chain(layout_shape.iter());
         let dims = zip(padded_tensor_shape, shape.iter())
             .enumerate()
             .map(|(dim, (&actual_len, &broadcast_len))| {
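Note: `IndexingIterBase`'s constructors are now generic over any `L: Layout` instead of taking `&DynLayout`. Since `Layout::shape()` returns the associated `L::Index<'_>` type rather than a plain `&[usize]`, generic code reaches the slice via `AsRef<[usize]>`, which is why the `.as_ref()` calls appear above. A minimal sketch of the pattern (the helper `count_elems` is hypothetical, not part of this diff):

```rust
// Sketch: code generic over `Layout` calls `.as_ref()` to view the shape
// as a `&[usize]`, because `shape()` returns the associated `Index` type.
fn count_elems<L: Layout>(layout: &L) -> usize {
    layout.shape().as_ref().iter().product()
}
```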
@@ -108,7 +155,7 @@ impl IndexingIterBase {
     }

     /// Create an iterator over offsets of a subset of elements in `tensor`.
-    fn slice(layout: &DynLayout, range: &[SliceItem]) -> IndexingIterBase {
+    fn slice<L: Layout>(layout: &L, range: &[SliceItem]) -> IndexingIterBase {
         assert!(
             range.len() == layout.ndim(),
             "slice dimensions {} do not match tensor dimensions {}",
@@ -231,8 +278,8 @@ enum IterKind<'a, T> {
 }

 impl<'a, T> Iter<'a, T> {
-    pub(super) fn new(view: &TensorView<'a, T>) -> Iter<'a, T> {
-        if let Some(data) = view.data() {
+    pub(super) fn new<L: Layout>(view: ViewRef<'a, '_, T, L>) -> Iter<'a, T> {
+        if let Some(data) = view.contiguous_data() {
             Iter {
                 iter: IterKind::Direct(data.iter()),
             }
@@ -243,14 +290,13 @@ impl<'a, T> Iter<'a, T> {
         }
     }

-    pub(super) fn slice(view: &TensorView<'a, T>, range: &[SliceItem]) -> Iter<'a, T> {
+    pub(super) fn slice<L: Layout>(
+        view: ViewRef<'a, '_, T, L>,
+        range: &[SliceItem],
+    ) -> Iter<'a, T> {
         let iter = IndexingIter {
-            base: IndexingIterBase::slice(view.layout(), range),
-
-            // Safety: The `Iterator::next` impl must only yield offsets from
-            // this slice that belong to the tensor slice. This is true of
-            // the offsets yielded by `IndexingIterBase`.
-            data: unsafe { view.data_unchecked() },
+            base: IndexingIterBase::slice(view.layout, range),
+            data: view.data,
         };
         Iter {
             iter: IterKind::Indexing(iter),
@@ -300,25 +346,17 @@ struct IndexingIter<'a, T> {
 }

 impl<'a, T> IndexingIter<'a, T> {
-    fn new(view: &TensorView<'a, T>) -> IndexingIter<'a, T> {
+    fn new<L: Layout>(view: ViewRef<'a, '_, T, L>) -> IndexingIter<'a, T> {
         IndexingIter {
-            base: IndexingIterBase::new(view.layout()),
-
-            // Safety: The `Iterator::next` impl must only yield offsets from
-            // this slice that belong to the tensor view. This is true of
-            // the offsets yielded by `IndexingIterBase`.
-            data: unsafe { view.data_unchecked() },
+            base: IndexingIterBase::new(view.layout),
+            data: view.data,
         }
     }

-    fn broadcast(view: &TensorView<'a, T>, shape: &[usize]) -> IndexingIter<'a, T> {
+    fn broadcast<L: Layout>(view: ViewRef<'a, '_, T, L>, shape: &[usize]) -> IndexingIter<'a, T> {
         IndexingIter {
-            base: IndexingIterBase::broadcast(view.layout(), shape),
-
-            // Safety: The `Iterator::next` impl must only yield offsets from
-            // this slice that belong to the broadcasted tensor view. This is
-            // true of the offsets yielded by `IndexingIterBase`.
-            data: unsafe { view.data_unchecked() },
+            base: IndexingIterBase::broadcast(view.layout, shape),
+            data: view.data,
         }
     }
 }
@@ -360,14 +398,14 @@ enum IterMutKind<'a, T> {
 }

 impl<'a, T> IterMut<'a, T> {
-    pub(super) fn new(data: &'a mut [T], layout: &DynLayout) -> IterMut<'a, T> {
-        if layout.is_contiguous() {
+    pub(super) fn new<L: Layout>(view: MutViewRef<'a, '_, T, L>) -> IterMut<'a, T> {
+        if view.layout.is_contiguous() {
             IterMut {
-                iter: IterMutKind::Direct(data.iter_mut()),
+                iter: IterMutKind::Direct(view.data.iter_mut()),
             }
         } else {
             IterMut {
-                iter: IterMutKind::Indexing(IndexingIterMut::new(data, layout)),
+                iter: IterMutKind::Indexing(IndexingIterMut::new(view)),
             }
         }
     }
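Note: `ViewRef` deliberately separates the data lifetime `'d` from the layout borrow `'l`. The iterators copy whatever layout information they need up front but keep borrowing only the data, so the returned iterator can outlive the layout borrow. A hypothetical crate-internal illustration (not in the diff):

```rust
// The returned reference borrows the *data* for 'd, independent of how
// long the ViewRef (and its layout borrow) lives.
fn first_contiguous<'d, T>(view: ViewRef<'d, '_, T, DynLayout>) -> Option<&'d T> {
    view.contiguous_data().and_then(|data| data.first())
}
```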
@@ -415,15 +453,15 @@ struct IndexingIterMut<'a, T> {
 }

 impl<'a, T> IndexingIterMut<'a, T> {
-    fn new(data: &'a mut [T], layout: &DynLayout) -> IndexingIterMut<'a, T> {
+    fn new<L: Layout>(view: MutViewRef<'a, '_, T, L>) -> IndexingIterMut<'a, T> {
         // See notes in `Layout` about internal overlap.
         assert!(
-            !layout.is_broadcast(),
+            !view.layout.is_broadcast(),
             "Cannot mutably iterate over broadcasting view"
         );
         IndexingIterMut {
-            base: IndexingIterBase::new(layout),
-            data,
+            base: IndexingIterBase::new(view.layout),
+            data: view.data,
         }
     }
 }
@@ -467,19 +505,19 @@ pub struct Offsets {
 }

 impl Offsets {
-    pub fn new(layout: &DynLayout) -> Offsets {
+    pub fn new<L: Layout>(layout: &L) -> Offsets {
         Offsets {
             base: IndexingIterBase::new(layout),
         }
     }

-    pub fn broadcast(layout: &DynLayout, shape: &[usize]) -> Offsets {
+    pub fn broadcast<L: Layout>(layout: &L, shape: &[usize]) -> Offsets {
         Offsets {
             base: IndexingIterBase::broadcast(layout, shape),
         }
     }

-    pub fn slice(layout: &DynLayout, range: &[SliceItem]) -> Offsets {
+    pub fn slice<L: Layout>(layout: &L, range: &[SliceItem]) -> Offsets {
         Offsets {
             base: IndexingIterBase::slice(layout, range),
         }
@@ -563,16 +601,20 @@ fn can_broadcast_by_cycling(from_shape: &[usize], to_shape: &[usize]) -> bool {
 }

 impl<'a, T> BroadcastIter<'a, T> {
-    pub fn new(view: &TensorView<'a, T>, to_shape: &[usize]) -> BroadcastIter<'a, T> {
+    pub(crate) fn new<L: Layout>(
+        view: ViewRef<'a, '_, T, L>,
+        to_shape: &[usize],
+    ) -> BroadcastIter<'a, T> {
+        let tmp_view = view.clone();
         let iter = match (
-            view.data(),
-            can_broadcast_by_cycling(view.shape(), to_shape),
+            view.contiguous_data(),
+            can_broadcast_by_cycling(view.shape().as_ref(), to_shape),
         ) {
             (Some(data), true) => {
                 let iter_len = to_shape.iter().product();
                 BroadcastIterKind::Direct(data.iter().cycle().take(iter_len))
             }
-            _ => BroadcastIterKind::Indexing(IndexingIter::broadcast(view, to_shape)),
+            _ => BroadcastIterKind::Indexing(IndexingIter::broadcast(tmp_view, to_shape)),
         };
         BroadcastIter { iter }
     }
@@ -771,21 +813,21 @@ struct LaneRanges {
 }

 impl LaneRanges {
-    fn new<T>(tensor: &TensorView<T>, dim: usize) -> LaneRanges {
-        let slice_starts: Vec<SliceItem> = (0..tensor.ndim())
+    fn new<L: Layout>(layout: &L, dim: usize) -> LaneRanges {
+        let slice_starts: Vec<SliceItem> = (0..layout.ndim())
             .map(|i| {
                 if i == dim {
                     (0..1).into()
                 } else {
-                    (0..(tensor.shape()[i] as isize)).into()
+                    (0..(layout.size(i) as isize)).into()
                 }
             })
             .collect();
-        let offsets = tensor.slice_offsets(&slice_starts);
+        let offsets = Offsets::slice(layout, &slice_starts);
         LaneRanges {
             offsets,
-            dim_size: tensor.size(dim),
-            dim_stride: tensor.stride(dim),
+            dim_size: layout.size(dim),
+            dim_stride: layout.stride(dim),
         }
     }
 }
@@ -835,12 +877,10 @@ impl<'a, T> ExactSizeIterator for Lane<'a, T> {}
 impl<'a, T> Lanes<'a, T> {
     /// Create an iterator which yields all possible slices over the `dim`
     /// dimension of `tensor`.
-    pub(crate) fn new(tensor: TensorView<'a, T>, dim: usize) -> Lanes<'a, T> {
+    pub(crate) fn new<L: Layout>(view: ViewRef<'a, '_, T, L>, dim: usize) -> Lanes<'a, T> {
         Lanes {
-            // Safety: `Lane`s yielded by this iterator will only yield elements
-            // that belong to this lane.
-            data: unsafe { tensor.data_unchecked() },
-            ranges: LaneRanges::new(&tensor, dim),
+            data: view.data,
+            ranges: LaneRanges::new(view.layout, dim),
         }
     }
 }
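Note: for orientation, a "lane" is a 1D slice along one axis, and the internals above serve the unchanged public `lanes` API. A small sketch using this crate's dyn-rank tensor:

```rust
// Lanes over axis 1 of a [2, 3] tensor are its rows; each `Lane` iterates
// over one row's elements.
let x = Tensor::from_data(&[2, 3], vec![0, 1, 2, 3, 4, 5]);
let rows: Vec<Vec<i32>> = x.lanes(1).map(|lane| lane.copied().collect()).collect();
assert_eq!(rows, vec![vec![0, 1, 2], vec![3, 4, 5]]);
```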
@@ -865,22 +905,22 @@ impl<'a, T> Iterator for Lanes<'a, T> {
 /// in implementing this for an iterator that returns mutable references, but
 /// it has a similar interface.
 pub struct LanesMut<'a, T> {
-    tensor: TensorViewMut<'a, T>,
+    data: &'a mut [T],
     ranges: LaneRanges,
 }

 impl<'a, T> LanesMut<'a, T> {
     /// Create an iterator which yields all possible slices over the `dim`
-    /// dimension of `tensor`.
-    pub(crate) fn new(tensor: TensorViewMut<'a, T>, dim: usize) -> LanesMut<'a, T> {
+    /// dimension of `view`.
+    pub(crate) fn new<L: Layout>(view: MutViewRef<'a, '_, T, L>, dim: usize) -> LanesMut<'a, T> {
         // See notes in `Layout` about internal overlap.
         assert!(
-            !tensor.layout().is_broadcast(),
+            !view.layout.is_broadcast(),
             "Cannot mutably iterate over broadcasting view"
         );
         LanesMut {
-            ranges: LaneRanges::new(&tensor.view(), dim),
-            tensor,
+            ranges: LaneRanges::new(view.layout, dim),
+            data: view.data,
         }
     }
 }
@@ -913,7 +953,7 @@ impl<'a, T> Iterator for LanesMut<'a, T> {
             // Safety: This is a non-broadcasting view, so each `LaneMut`
             // yielded by this iterator will yield a distinct set of elements.
             let slice = unsafe {
-                let slice = &mut self.tensor.data_mut_unchecked()[range];
+                let slice = &mut self.data[range];
                 std::mem::transmute::<&mut [T], &'a mut [T]>(slice)
             };

@@ -1011,14 +1051,14 @@ mod tests {
     #[test]
     fn test_lanes_empty() {
         let x = Tensor::<i32>::zeros(&[5, 0]);
-        assert!(Lanes::new(x.view(), 0).next().is_none());
-        assert!(Lanes::new(x.view(), 1).next().is_none());
+        assert!(Lanes::new(x.view().view_ref(), 0).next().is_none());
+        assert!(Lanes::new(x.view().view_ref(), 1).next().is_none());
     }

     #[test]
     fn test_lanes_mut_empty() {
         let mut x = Tensor::<i32>::zeros(&[5, 0]);
-        assert!(LanesMut::new(x.view_mut(), 0).next().is_none());
-        assert!(LanesMut::new(x.view_mut(), 1).next().is_none());
+        assert!(LanesMut::new(x.mut_view_ref(), 0).next().is_none());
+        assert!(LanesMut::new(x.mut_view_ref(), 1).next().is_none());
     }
 }
diff --git a/rten-tensor/src/layout.rs b/rten-tensor/src/layout.rs
index bf9d7d87..8463b70c 100644
--- a/rten-tensor/src/layout.rs
+++ b/rten-tensor/src/layout.rs
@@ -7,7 +7,6 @@ use crate::errors::{DimensionError, FromDataError, SliceError};
 use crate::index_iterator::{DynIndices, NdIndices};
 use crate::overlap::{is_contiguous, may_have_internal_overlap};
 use crate::range::SliceItem;
-use crate::tensor::TensorIndex;

 /// Return true if `permutation` is a valid permutation of dimensions for
 /// a tensor of rank `ndim`.
@@ -16,19 +15,40 @@ pub fn is_valid_permutation(ndim: usize, permutation: &[usize]) -> bool {
         && (0..ndim).all(|dim| permutation.iter().filter(|d| **d == dim).count() == 1)
 }

-/// Provides methods for querying the shape and strides of a tensor.
+/// Layouts describe the shape of a tensor, i.e. the number of dimensions and
+/// size of each, and the mapping between indices and offsets in the data
+/// storage.
+///
+/// The main implementations are [NdLayout], where the dimension count is known
+/// statically, and [DynLayout], where the dimension count is only known at
+/// runtime.
 pub trait Layout {
     /// Type used to represent indices.
     ///
     /// It is assumed that this type can also represent the shape and strides
     /// of the tensor.
-    type Index<'a>: AsRef<[usize]> + std::fmt::Debug + PartialEq<Self::Index<'a>>
-    where
-        Self: 'a;
+    type Index<'a>: AsRef<[usize]> + Clone + std::fmt::Debug + PartialEq<Self::Index<'a>>;

     /// Iterator over indices in this tensor.
     type Indices;

+    /// Map an index to a storage offset.
+    ///
+    /// Panics if any dimension of the index is out of bounds.
+    fn offset(&self, index: Self::Index<'_>) -> usize {
+        self.try_offset(index.clone()).unwrap_or_else(|| {
+            panic!(
+                "index {:?} out of bounds for shape {:?}",
+                index.as_ref(),
+                self.shape().as_ref()
+            );
+        })
+    }
+
+    /// Map an index to a storage offset, or return `None` if the index is out
+    /// of bounds along any dimension.
+    fn try_offset(&self, index: Self::Index<'_>) -> Option<usize>;
+
     /// Return the number of dimensions.
     fn ndim(&self) -> usize;
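Note: `offset` and `try_offset` are now trait-level API, with `offset` given a default implementation in terms of `try_offset`. A sketch with `DynLayout`, whose `Index` type is `&[usize]` (strides for shape `[2, 3]` default to `[3, 1]`):

```rust
let layout = DynLayout::from_shape(&[2, 3]);
assert_eq!(layout.try_offset(&[1, 2]), Some(5)); // 1 * 3 + 2 * 1
assert_eq!(layout.try_offset(&[1, 3]), None);    // column out of bounds
assert_eq!(layout.offset(&[0, 1]), 1);           // panics if invalid
```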
@@ -41,6 +61,12 @@ pub trait Layout {
         is_contiguous(self.shape(), self.strides())
     }

+    /// Return true if iterating over elements in this layout will visit
+    /// elements multiple times.
+    fn is_broadcast(&self) -> bool {
+        !self.is_empty() && self.strides().as_ref().iter().any(|&stride| stride == 0)
+    }
+
     /// Returns true if the array has no elements.
     fn is_empty(&self) -> bool {
         self.len() == 0
@@ -78,7 +104,7 @@ pub trait Layout {
         //
         // If the tensor has fewer dimensions, pretend that it was prefixed with
         // 1-length dimensions to make the dimension counts equal.
-        let target_dims = target_shape[target_shape.len() - self.shape().len()..]
+        let target_dims = target_shape[target_shape.len() - self.shape().as_ref().len()..]
             .iter()
             .copied();

@@ -157,6 +183,13 @@ impl<const N: usize> Layout for NdLayout<N> {
         self.shape.iter().product()
     }

+    fn try_offset(&self, index: [usize; N]) -> Option<usize> {
+        if !self.index_valid(index) {
+            return None;
+        }
+        Some(self.offset_unchecked(index))
+    }
+
     #[inline]
     fn shape(&self) -> Self::Index<'_> {
         self.shape
@@ -302,26 +335,6 @@ impl<const N: usize> NdLayout<N> {
         valid
     }

-    /// Return the offset in the slice that an index maps to.
-    pub fn offset(&self, index: [usize; N]) -> usize {
-        assert!(
-            self.index_valid(index),
-            "Index {:?} out of bounds for shape {:?}",
-            index,
-            self.shape
-        );
-        self.offset_unchecked(index)
-    }
-
-    /// Return the offset in the slice that an index maps to, or `None` if it
-    /// is out of bounds.
-    pub fn try_offset(&self, index: [usize; N]) -> Option<usize> {
-        if !self.index_valid(index) {
-            return None;
-        }
-        Some(self.offset_unchecked(index))
-    }
-
     /// Return the offset in the slice that an index maps to.
     ///
     /// Unlike `offset`, this does not bounds-check elements of `index` against
@@ -517,6 +530,19 @@ impl Layout for DynLayout {
         self.shape().iter().product()
     }

+    #[inline]
+    fn try_offset(&self, index: Self::Index<'_>) -> Option<usize> {
+        let shape = self.shape();
+        let strides = self.strides();
+        let mut valid = index.as_ref().len() == shape.len();
+        let mut offset = 0;
+        for (idx, (size, stride)) in index.as_ref().iter().zip(shape.iter().zip(strides.iter())) {
+            valid = valid && idx < size;
+            offset += idx * stride;
+        }
+        valid.then_some(offset)
+    }
+
     fn is_empty(&self) -> bool {
         self.len() == 0
     }
@@ -559,7 +585,7 @@ impl Layout for DynLayout {
 impl DynLayout {
     /// Construct a layout with dimension sizes given by `shape` and default
     /// (contiguous) strides.
-    pub fn new(shape: &[usize]) -> DynLayout {
+    pub fn from_shape(shape: &[usize]) -> DynLayout {
         DynLayout {
             shape_and_strides: Self::contiguous_shape_and_strides(shape),
         }
@@ -671,12 +697,6 @@ impl DynLayout {
         self.shape_and_strides[dim] = new_size;
     }

-    /// Return true if iterating over elements in this layout will visit
-    /// elements multiple times.
-    pub fn is_broadcast(&self) -> bool {
-        !self.is_empty() && self.strides().iter().any(|&stride| stride == 0)
-    }
-
     pub fn make_contiguous(&mut self) {
         self.shape_and_strides = Self::contiguous_shape_and_strides(self.shape());
     }
@@ -753,7 +773,7 @@ impl DynLayout {
             self.is_contiguous(),
             "can only reshape a contiguous tensor/view"
         );
-        *self = DynLayout::new(shape);
+        *self = DynLayout::from_shape(shape);
     }

     pub fn reshaped(&self, shape: &[usize]) -> DynLayout {
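Note: the `DynLayout::new` → `DynLayout::from_shape` rename is mechanical at call sites; the constructed strides are contiguous row-major:

```rust
// before: let layout = DynLayout::new(&[2, 3]);
let layout = DynLayout::from_shape(&[2, 3]);
assert!(layout.is_contiguous());
assert_eq!(layout.strides(), [3, 1]);
```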
@@ -762,39 +782,21 @@ impl DynLayout {
         reshaped
     }

-    /// Return the offset in the slice that an index maps to, or `None` if it
-    /// is out of bounds.
-    #[inline]
-    pub fn try_offset<Idx: TensorIndex>(&self, index: Idx) -> Option<usize> {
-        let shape = self.shape();
-        let strides = self.strides();
-        let mut valid = index.len() == shape.len();
-        let mut offset = 0;
-        for (idx, (size, stride)) in index.iter().zip(shape.iter().zip(strides.iter())) {
-            valid = valid && idx < size;
-            offset += idx * stride;
-        }
-        valid.then_some(offset)
-    }
-
-    /// Return the offset of the element with a given index.
-    pub fn offset<Idx: TensorIndex>(&self, index: Idx) -> usize {
-        self.try_offset(index).expect("invalid index")
-    }
-
     /// Return the offset of the slice that begins at the given index.
-    pub fn slice_offset<Idx: TensorIndex>(&self, index: Idx) -> usize {
+    pub fn slice_offset<Idx: AsRef<[usize]>>(&self, index: Idx) -> usize {
+        let index = index.as_ref();
+        assert!(index.len() <= self.ndim());
         let shape = self.shape();
         let mut offset = 0;
         for i in 0..index.len() {
             assert!(
-                index.index(i) < shape[i],
+                index[i] < shape[i],
                 "Invalid index {} for dim {}",
-                index.index(i),
+                index[i],
                 i
             );
-            offset += index.index(i) * self.stride(i)
+            offset += index[i] * self.stride(i)
         }
         offset
     }
@@ -837,6 +839,12 @@ impl<const N: usize> From<&NdLayout<N>> for DynLayout {
     }
 }

+impl<const N: usize> From<NdLayout<N>> for DynLayout {
+    fn from(value: NdLayout<N>) -> DynLayout {
+        DynLayout::from(&value)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::iter::zip;
@@ -847,11 +855,11 @@ mod tests {
     #[test]
     fn test_is_broadcast() {
         // Non-empty, contiguous layout
-        let layout = DynLayout::new(&[5, 5]);
+        let layout = DynLayout::from_shape(&[5, 5]);
         assert!(!layout.is_broadcast());

         // Empty layout
-        let layout = DynLayout::new(&[5, 0]);
+        let layout = DynLayout::from_shape(&[5, 0]);
         assert!(!layout.is_broadcast());

         // Broadcasting layout
@@ -894,7 +902,7 @@ mod tests {
     #[test]
     fn test_move_axis() {
-        let mut layout = DynLayout::new(&[2, 4, 8]);
+        let mut layout = DynLayout::from_shape(&[2, 4, 8]);
         assert_eq!(layout.strides(), [32, 8, 1]);

         layout.move_axis(1, 0);
@@ -913,41 +921,41 @@ mod tests {
     #[test]
     #[should_panic]
     fn test_move_axis_invalid_from() {
-        let mut layout = DynLayout::new(&[2, 4, 8]);
+        let mut layout = DynLayout::from_shape(&[2, 4, 8]);
         layout.move_axis(3, 0);
     }

     #[test]
     #[should_panic]
     fn test_move_axis_invalid_to() {
-        let mut layout = DynLayout::new(&[2, 4, 8]);
+        let mut layout = DynLayout::from_shape(&[2, 4, 8]);
         layout.move_axis(0, 3);
     }

     #[test]
     #[should_panic(expected = "permutation is invalid")]
     fn test_permute_invalid_len() {
-        let mut layout = DynLayout::new(&[5, 5]);
+        let mut layout = DynLayout::from_shape(&[5, 5]);
         layout.permute(&[1, 0, 3]);
     }

     #[test]
     #[should_panic(expected = "permutation is invalid")]
     fn test_permute_too_few_dims() {
-        let mut layout = DynLayout::new(&[5, 5]);
+        let mut layout = DynLayout::from_shape(&[5, 5]);
         layout.permute(&[1]);
     }

     #[test]
     #[should_panic(expected = "permutation is invalid")]
     fn test_permute_repeated_dims() {
-        let mut layout = DynLayout::new(&[5, 5]);
+        let mut layout = DynLayout::from_shape(&[5, 5]);
         layout.permute(&[1, 1]);
     }

     #[test]
     fn test_squeezed() {
-        let layout = DynLayout::new(&[1, 1, 10, 20]);
+        let layout = DynLayout::from_shape(&[1, 1, 10, 20]);
         let squeezed = layout.squeezed();
         assert_eq!(squeezed.shape(), &[10, 20]);
         assert_eq!(squeezed.strides(), &[20, 1]);
@@ -956,41 +964,41 @@ mod tests {
     #[test]
     #[should_panic(expected = "Slice index is invalid for tensor shape")]
     fn test_slice_invalid_index() {
-        let layout = DynLayout::new(&[3, 5]);
+        let layout = DynLayout::from_shape(&[3, 5]);
         layout.slice(&[SliceItem::Index(4), SliceItem::Index(0)]);
     }
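Note on the `slice_offset` change earlier in this file: it now takes any `AsRef<[usize]>` and asserts that the index has no more entries than there are dimensions, so a partial index yields the offset where that sub-tensor begins. A small worked example:

```rust
let layout = DynLayout::from_shape(&[4, 3, 2]); // strides [6, 2, 1]
assert_eq!(layout.slice_offset([1, 2]), 10);    // 1 * 6 + 2 * 2
```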
     #[test]
     #[should_panic(expected = "Slice index is invalid for tensor shape")]
     fn test_slice_invalid_negative_index() {
-        let layout = DynLayout::new(&[3, 5]);
+        let layout = DynLayout::from_shape(&[3, 5]);
         layout.slice(&[SliceItem::Index(-4)]);
     }

     #[test]
     #[should_panic(expected = "Slice range is invalid for tensor shape")]
     fn test_slice_invalid_range() {
-        let layout = DynLayout::new(&[3, 5]);
+        let layout = DynLayout::from_shape(&[3, 5]);
         layout.slice(&[SliceItem::Range((1..4).into()), SliceItem::Index(0)]);
     }

     #[test]
     #[should_panic(expected = "Slice range is invalid for tensor shape")]
     fn test_slice_invalid_from_range() {
-        let layout = DynLayout::new(&[3, 5]);
+        let layout = DynLayout::from_shape(&[3, 5]);
         layout.slice(&[SliceItem::Range((4..).into()), SliceItem::Index(0)]);
     }

     #[test]
     #[should_panic(expected = "Cannot slice with negative step")]
     fn test_slice_negative_step() {
-        let layout = DynLayout::new(&[3, 5]);
+        let layout = DynLayout::from_shape(&[3, 5]);
         layout.slice(&[SliceItem::full_range(), SliceItem::range(0, None, -1)]);
     }

     #[test]
     fn test_size_stride() {
-        let layout = DynLayout::new(&[10, 20, 30]);
+        let layout = DynLayout::from_shape(&[10, 20, 30]);
         for (dim, (&size, &stride)) in
             zip(layout.shape().iter(), layout.strides().iter()).enumerate()
         {
diff --git a/rten-tensor/src/lib.rs b/rten-tensor/src/lib.rs
index c6ceea1b..3a494250 100644
--- a/rten-tensor/src/lib.rs
+++ b/rten-tensor/src/lib.rs
@@ -49,7 +49,7 @@ pub use iterators::{
     AxisChunks, AxisChunksMut, AxisIter, AxisIterMut, BroadcastIter, InnerIter, InnerIterMut,
     Iter, IterMut, Lanes, LanesMut, Offsets,
 };
-pub use layout::{is_valid_permutation, DynLayout, Layout, MatrixLayout};
+pub use layout::{is_valid_permutation, DynLayout, Layout, MatrixLayout, NdLayout};
 pub use ndtensor::{
     Matrix, MatrixMut, NdTensor, NdTensorBase, NdTensorView, NdTensorViewMut, NdView,
 };
diff --git a/rten-tensor/src/ndtensor.rs b/rten-tensor/src/ndtensor.rs
index 12bec846..8c0580b3 100644
--- a/rten-tensor/src/ndtensor.rs
+++ b/rten-tensor/src/ndtensor.rs
@@ -5,7 +5,7 @@ use std::ops::{Index, IndexMut};
 use crate::errors::{DimensionError, FromDataError};
 use crate::index_iterator::NdIndices;
-use crate::iterators::{Iter, IterMut};
+use crate::iterators::{Iter, IterMut, MutViewRef, ViewRef};
 use crate::layout::{Layout, MatrixLayout, NdLayout, OverlapPolicy};
 use crate::{IntoSliceItems, RandomSource, TensorBase, TensorView, TensorViewMut, View};

@@ -404,7 +404,11 @@ impl<'a, T, const N: usize> NdTensorView<'a, T, N> {
     }

     pub fn iter(&self) -> Iter<'a, T> {
-        Iter::new(&self.as_dyn())
+        Iter::new(self.view_ref())
+    }
+
+    fn view_ref(&self) -> ViewRef<'a, '_, T, NdLayout<N>> {
+        ViewRef::new(self.data, &self.layout)
     }

     fn broadcast<const M: usize>(&self, shape: [usize; M]) -> NdTensorView<'a, T, M> {
@@ -541,7 +545,11 @@ impl<T, S: AsRef<[T]> + AsMut<[T]>, const N: usize> NdTensorBase<T, S, N> {
     /// Return a mutable iterator over elements of this tensor.
     pub fn iter_mut(&mut self) -> IterMut<T> {
-        IterMut::new(self.data.as_mut(), &self.layout.as_dyn())
+        IterMut::new(self.mut_view_ref())
+    }
+
+    fn mut_view_ref(&mut self) -> MutViewRef<T, NdLayout<N>> {
+        MutViewRef::new(self.data.as_mut(), &self.layout)
     }

     /// Replace elements of this tensor with `f(element)`.
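Note: `NdTensorView::iter` previously round-tripped through `self.as_dyn()`, materializing a `DynLayout` just to iterate; with `ViewRef` it borrows its own `NdLayout<N>` directly. A side effect is that both tensor flavors hand out the same `Iter<'a, T>` type, so helpers can target it once. A hedged sketch (`sum_all` is illustrative, not from the crate):

```rust
// Accepts the iterator from either a Tensor view or an NdTensor view.
fn sum_all<'a>(iter: Iter<'a, f32>) -> f32 {
    iter.copied().sum()
}
```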
@@ -662,7 +670,7 @@ impl<T, S: AsRef<[T]> + AsMut<[T]>, const N: usize> IndexMut<[usize; N]> for NdTensorBase<T, S, N> {
 }

 impl<T, S: AsRef<[T]>, const N: usize> Layout for NdTensorBase<T, S, N> {
-    type Index<'a> = [usize; N] where S: 'a, T: 'a;
+    type Index<'a> = [usize; N];
     type Indices = NdIndices<N>;

     fn ndim(&self) -> usize {
@@ -673,6 +681,10 @@ impl<T, S: AsRef<[T]>, const N: usize> Layout for NdTensorBase<T, S, N> {
         self.layout.len()
     }

+    fn try_offset(&self, index: [usize; N]) -> Option<usize> {
+        self.layout.try_offset(index)
+    }
+
     fn is_empty(&self) -> bool {
         self.layout.is_empty()
     }
@@ -753,13 +765,17 @@ pub struct UncheckedNdTensor<T, S: AsRef<[T]>, const N: usize> {
 }

 impl<T, S: AsRef<[T]>, const N: usize> Layout for UncheckedNdTensor<T, S, N> {
-    type Index<'a> = [usize; N] where S: 'a, T: 'a;
+    type Index<'a> = [usize; N];
     type Indices = NdIndices<N>;

     fn ndim(&self) -> usize {
         N
     }

+    fn try_offset(&self, index: [usize; N]) -> Option<usize> {
+        self.base.try_offset(index)
+    }
+
     fn len(&self) -> usize {
         self.base.len()
     }
diff --git a/rten-tensor/src/tensor.rs b/rten-tensor/src/tensor.rs
index 562c5540..4d04a68f 100644
--- a/rten-tensor/src/tensor.rs
+++ b/rten-tensor/src/tensor.rs
@@ -9,53 +9,13 @@ use std::ops::{Index, IndexMut, Range};
 use crate::errors::SliceError;
 use crate::iterators::{
     AxisChunks, AxisChunksMut, AxisIter, AxisIterMut, BroadcastIter, InnerIter, InnerIterMut, Iter,
-    IterMut, Lanes, LanesMut, Offsets,
+    IterMut, Lanes, LanesMut, MutViewRef, ViewRef,
 };
 use crate::layout::{DynLayout, Layout};
 use crate::ndtensor::{NdTensorBase, NdTensorView, NdTensorViewMut};
 use crate::range::{IntoSliceItems, SliceItem};
 use crate::rng::XorShiftRng;

-/// Trait for indexing a `Tensor`
-pub trait TensorIndex {
-    type Iter<'a>: Iterator<Item = usize>
-    where
-        Self: 'a;
-
-    /// Return the number of dimensions in the index.
-    fn len(&self) -> usize;
-
-    /// Return true if this index has zero dimensions (ie. is a scalar).
-    fn is_empty(&self) -> bool {
-        self.len() == 0
-    }
-
-    /// Return the index for dimension `dim`
-    fn index(&self, dim: usize) -> usize;
-
-    /// Return an iterator over sizes of dimensions in this index.
-    fn iter(&self) -> Self::Iter<'_>;
-}
-
-impl<Array: AsRef<[usize]>> TensorIndex for Array {
-    type Iter<'a> = std::slice::Iter<'a, usize> where Self: 'a;
-
-    #[inline]
-    fn len(&self) -> usize {
-        self.as_ref().len()
-    }
-
-    #[inline]
-    fn index(&self, dim: usize) -> usize {
-        self.as_ref()[dim]
-    }
-
-    #[inline]
-    fn iter(&self) -> Self::Iter<'_> {
-        self.as_ref().iter()
-    }
-}
-
 /// Multi-dimensional array view with a dynamic dimension count. This trait
 /// includes operations that are available on tensors that own their data
 /// ([Tensor]) as well as views ([TensorView], [TensorViewMut]).
@@ -121,20 +81,20 @@ pub trait View: Layout {
     /// Return the element at a given index, or `None` if the index is out of
     /// bounds in any dimension.
     #[inline]
-    fn get<I: TensorIndex>(&self, index: I) -> Option<&Self::Elem> {
+    fn get<I: AsRef<[usize]>>(&self, index: I) -> Option<&Self::Elem> {
         self.view().get(index)
     }

     /// Return an iterator over elements of this tensor, in their logical order.
     fn iter(&self) -> Iter<Self::Elem> {
-        Iter::new(&self.view())
+        self.view().iter()
     }

     /// Return an iterator over all 1D slices ("lanes") along a given axis.
     ///
     /// Each slice is an iterator over the elements in that lane.
     fn lanes(&self, dim: usize) -> Lanes<Self::Elem> {
-        Lanes::new(self.view(), dim)
+        self.view().lanes(dim)
     }

     /// Return a copy of this tensor with each element replaced by `f(element)`.
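Note: with `TensorIndex` deleted, indexing is bounded by plain `AsRef<[usize]>`, so fixed-size arrays, slices and `Vec`s all work:

```rust
let x = Tensor::from_data(&[2, 2], vec![1, 2, 3, 4]);
assert_eq!(x[[1, 0]], 3);            // [usize; 2]
assert_eq!(x.get([1, 1]), Some(&4)); // bounds-checked lookup
let idx = vec![0, 1];
assert_eq!(x.get(idx.as_slice()), Some(&2));
```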
@@ -152,12 +112,12 @@ pub trait View: Layout {
         };
         Tensor {
             data,
-            layout: DynLayout::new(self.shape().as_ref()),
+            layout: DynLayout::from_shape(self.shape().as_ref()),
             element_type: PhantomData,
         }
     }

-    /// Return an `NdTensor` version of this view.
+    /// Return a view with a static rank.
     ///
     /// Panics if the rank of this tensor is not `N`.
     fn nd_view<const N: usize>(&self) -> NdTensorView<Self::Elem, N> {
@@ -311,7 +271,7 @@ impl<T, S: AsRef<[T]>> TensorBase<T, S> {
         );
         TensorBase {
             data,
-            layout: DynLayout::new(shape),
+            layout: DynLayout::from_shape(shape),
             element_type: PhantomData,
         }
     }
@@ -365,23 +325,6 @@ impl<T, S: AsRef<[T]>> TensorBase<T, S> {
     pub fn layout(&self) -> &DynLayout {
         &self.layout
     }
-
-    /// Return an iterator over offsets of elements in this tensor, in their
-    /// logical order.
-    ///
-    /// See also the notes for `slice_offsets`.
-    #[cfg(test)]
-    fn offsets(&self) -> Offsets {
-        Offsets::new(self.layout())
-    }
-
-    /// Return an iterator over offsets of elements in this tensor.
-    ///
-    /// Note that the offset order of the returned iterator will become incorrect
-    /// if the tensor's layout is modified during iteration.
-    pub(crate) fn slice_offsets(&self, range: &[SliceItem]) -> Offsets {
-        Offsets::slice(self.layout(), range)
-    }
 }

 /// Specialized versions of the [View] methods for immutable views.
@@ -409,7 +352,7 @@ impl<'a, T> TensorView<'a, T> {
             self.can_broadcast_to(shape),
             "Cannot broadcast to specified shape"
         );
-        BroadcastIter::new(self, shape)
+        BroadcastIter::new(self.view_ref(), shape)
     }

     pub fn data(&self) -> Option<&'a [T]> {
@@ -428,8 +371,8 @@
     }

     #[inline]
-    fn get<I: TensorIndex>(&self, index: I) -> Option<&'a T> {
-        let offset = self.layout.try_offset(index)?;
+    fn get<I: AsRef<[usize]>>(&self, index: I) -> Option<&'a T> {
+        let offset = self.layout.try_offset(index.as_ref())?;
         Some(&self.data[offset])
     }

@@ -438,7 +381,11 @@
     }

     pub fn iter(&self) -> Iter<'a, T> {
-        Iter::new(self)
+        Iter::new(self.view_ref())
+    }
+
+    pub(crate) fn view_ref(&self) -> ViewRef<'a, '_, T, DynLayout> {
+        ViewRef::new(self.data, &self.layout)
     }

     pub fn item(&self) -> Option<&'a T> {
@@ -450,33 +397,14 @@
     }

     pub fn lanes(&self, dim: usize) -> Lanes<'a, T> {
-        Lanes::new(self.clone(), dim)
+        Lanes::new(self.view_ref(), dim)
     }

     pub fn nd_view<const N: usize>(&self) -> NdTensorView<'a, T, N> {
-        self.nd_slice([])
-    }
-
-    /// Return an N-dimensional view of a slice of this tensor.
-    ///
-    /// See notes in [TensorBase::nd_view].
-    ///
-    /// `base` specifies zero or more indices to slice the view with, and `N`
-    /// is the rank of the returned view. `B + N` must equal `self.ndim()`.
-    pub fn nd_slice<const B: usize, const N: usize>(
-        &self,
-        base: [usize; B],
-    ) -> NdTensorView<'a, T, N> {
-        assert!(B + N == self.ndim());
-        let offset = self.layout.slice_offset(base);
-
-        // Safety: The offset for the slice is valid, and `NdTensorView` will
-        // only expose elements from `data` that belong to the sliced view.
-        let data = unsafe { &self.data_unchecked()[offset..] };
-        let strides = self.layout.strides()[self.ndim() - N..].try_into().unwrap();
-        let shape = self.layout.shape()[self.ndim() - N..].try_into().unwrap();
-
-        NdTensorView::from_slice(data, shape, Some(strides)).unwrap()
+        assert!(self.ndim() == N);
+        let shape: [usize; N] = self.shape().try_into().unwrap();
+        let strides: [usize; N] = self.strides().try_into().unwrap();
+        NdTensorView::from_slice(self.data, shape, Some(strides)).unwrap()
     }

     pub fn permuted(&self, dims: &[usize]) -> TensorView<'a, T> {
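Note: call sites that used `nd_slice` to fix the rank and slice in one step now compose the two operations, as the `conv.rs`, `norm.rs` and `pooling.rs` changes later in this diff do. Roughly, mirroring those call sites (whether the output rank of `slice` is inferred or needs a turbofish depends on context; this is a sketch, not a verified signature):

```rust
// before: let plane = x.view().nd_slice([5, 3]); // rank-2 view of x[5][3][..][..]
// after: fix the rank first, then slice.
let plane = x.nd_view::<4>().slice([5, 3]);
```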
@@ -517,7 +445,7 @@ impl<'a, T> TensorView<'a, T> {
     }

     pub fn slice_iter(&self, range: &[SliceItem]) -> Iter<'a, T> {
-        Iter::slice(self, range)
+        Iter::slice(self.view_ref(), range)
     }

     pub fn squeezed(&self) -> TensorView<'a, T> {
@@ -542,7 +470,7 @@ impl<'a, T> TensorView<'a, T> {
         let data = self.to_vec();
         TensorBase {
             data: Cow::Owned(data),
-            layout: DynLayout::new(self.layout().shape()),
+            layout: DynLayout::from_shape(self.layout().shape()),
             element_type: PhantomData,
         }
     }
@@ -558,7 +486,7 @@ impl<'a, T> TensorView<'a, T> {
 }

 impl<T, S: AsRef<[T]>> Layout for TensorBase<T, S> {
-    type Index<'a> = <DynLayout as Layout>::Index<'a> where S: 'a, T: 'a;
+    type Index<'a> = <DynLayout as Layout>::Index<'a>;
     type Indices = <DynLayout as Layout>::Indices;

     /// Return the number of dimensions.
@@ -566,6 +494,10 @@ impl<T, S: AsRef<[T]>> Layout for TensorBase<T, S> {
         self.layout.ndim()
     }

+    fn try_offset(&self, index: &[usize]) -> Option<usize> {
+        self.layout.try_offset(index)
+    }
+
     /// Returns the number of elements in the array.
     fn len(&self) -> usize {
         self.layout.len()
@@ -650,10 +582,12 @@ impl<T, S: AsRef<[T]>> View for TensorBase<T, S> {
     }
 }

-impl<I: TensorIndex, T, S: AsRef<[T]>> Index<I> for TensorBase<T, S> {
+impl<I: AsRef<[usize]>, T, S: AsRef<[T]>> Index<I> for TensorBase<T, S> {
     type Output = T;
+
     fn index(&self, index: I) -> &Self::Output {
-        &self.data.as_ref()[self.layout.offset(index)]
+        let offset = self.layout.offset(index.as_ref());
+        &self.data.as_ref()[offset]
     }
 }

@@ -677,19 +611,13 @@ impl<T, S: AsRef<[T]> + AsMut<[T]>> TensorBase<T, S> {
         self.is_contiguous().then_some(self.data.as_mut())
     }

-    /// Return the element buffer for this tensor as a mutable slice.
-    ///
-    /// Unlike [TensorBase::data_mut] this does not check if the data is
-    /// contiguous. If multiple mutable views of a non-contiguous tensor exist,
-    /// the returned slice may overlap with other mutable slices.
-    pub(crate) unsafe fn data_mut_unchecked(&mut self) -> &mut [T] {
-        self.data.as_mut()
-    }
-
     /// Return a mutable iterator over elements of this view.
     pub fn iter_mut(&mut self) -> IterMut<T> {
-        let layout = &self.layout;
-        IterMut::new(self.data.as_mut(), layout)
+        IterMut::new(self.mut_view_ref())
+    }
+
+    pub(crate) fn mut_view_ref(&mut self) -> MutViewRef<T, DynLayout> {
+        MutViewRef::new(self.data.as_mut(), &self.layout)
     }

     /// Return an iterator over mutable slices of this tensor along a given
@@ -705,8 +633,8 @@ impl<T, S: AsRef<[T]> + AsMut<[T]>> TensorBase<T, S> {
     /// Return the element at a given index, or `None` if the index is out of
     /// bounds in any dimension.
     #[inline]
-    pub fn get_mut<I: TensorIndex>(&mut self, index: I) -> Option<&mut T> {
-        let offset = self.layout.try_offset(index)?;
+    pub fn get_mut<I: AsRef<[usize]>>(&mut self, index: I) -> Option<&mut T> {
+        let offset = self.layout.try_offset(index.as_ref())?;
         Some(&mut self.data.as_mut()[offset])
     }

@@ -719,7 +647,7 @@ impl<T, S: AsRef<[T]> + AsMut<[T]>> TensorBase<T, S> {
     /// Return a mutable iterator over all 1D slices of this tensor along a
     /// given axis.
     pub fn lanes_mut(&mut self, dim: usize) -> LanesMut<T> {
-        LanesMut::new(self.view_mut(), dim)
+        LanesMut::new(self.mut_view_ref(), dim)
     }

     /// Replace elements of this tensor with `f(element)`.
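Note: `get_mut` follows the same `AsRef<[usize]>` convention on the mutable side:

```rust
let mut x = Tensor::from_data(&[2, 2], vec![1, 2, 3, 4]);
if let Some(v) = x.get_mut([0, 1]) {
    *v = 9;
}
assert_eq!(x[[0, 1]], 9);
```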
@@ -803,27 +731,14 @@ impl<T, S: AsRef<[T]> + AsMut<[T]>> TensorBase<T, S> {
         TensorViewMut::new(self.data.as_mut(), &self.layout)
     }

-    /// Return an N-dimensional slice of this tensor.
-    ///
-    /// This is the same as [TensorBase::nd_slice] except that the
-    /// returned view can be used to modify elements.
-    pub fn nd_slice_mut<const B: usize, const N: usize>(
-        &mut self,
-        base: [usize; B],
-    ) -> NdTensorViewMut<T, N> {
-        assert!(B + N == self.ndim());
-        let offset = self.layout.slice_offset(base);
-        let strides = self.layout.strides()[self.ndim() - N..].try_into().unwrap();
-        let shape = self.layout.shape()[self.ndim() - N..].try_into().unwrap();
-        let data = &mut self.data.as_mut()[offset..];
-        NdTensorViewMut::from_data(data, shape, Some(strides)).unwrap()
-    }
-
-    /// Return a mutable N-dimensional view of this tensor.
+    /// Return a mutable view with a static rank.
     ///
-    /// See notes in `[TensorBase::nd_view]`.
+    /// Panics if the rank of this tensor is not `N`.
     pub fn nd_view_mut<const N: usize>(&mut self) -> NdTensorViewMut<T, N> {
-        self.nd_slice_mut([])
+        assert!(self.ndim() == N);
+        let shape: [usize; N] = self.shape().try_into().unwrap();
+        let strides: [usize; N] = self.strides().try_into().unwrap();
+        NdTensorViewMut::from_data(self.data.as_mut(), shape, Some(strides)).unwrap()
     }
 }

@@ -837,9 +752,9 @@ impl<'a, T> TensorViewMut<'a, T> {
     }
 }

-impl<I: TensorIndex, T, S: AsRef<[T]> + AsMut<[T]>> IndexMut<I> for TensorBase<T, S> {
+impl<I: AsRef<[usize]>, T, S: AsRef<[T]> + AsMut<[T]>> IndexMut<I> for TensorBase<T, S> {
     fn index_mut(&mut self, index: I) -> &mut Self::Output {
-        let offset = self.layout.offset(index);
+        let offset = self.layout.offset(index.as_ref());
         &mut self.data.as_mut()[offset]
     }
 }

@@ -858,7 +773,7 @@ impl<T> Tensor<T> {
         let data = vec![T::default(); n_elts];
         Tensor {
             data,
-            layout: DynLayout::new(shape),
+            layout: DynLayout::from_shape(shape),
             element_type: PhantomData,
         }
     }
@@ -872,7 +787,7 @@ impl<T> Tensor<T> {
         let data = vec![value; n_elts];
         Tensor {
             data,
-            layout: DynLayout::new(shape),
+            layout: DynLayout::from_shape(shape),
             element_type: PhantomData,
         }
     }
@@ -916,7 +831,7 @@ impl<T> Tensor<T> {

     /// Clone this tensor with a new shape. The new shape must have the same
     /// total number of elements as the existing shape. See `reshape`.
-    pub fn clone_with_shape(&self, shape: &[usize]) -> Tensor<T>
+    pub fn to_shape(&self, shape: &[usize]) -> Tensor<T>
     where
         T: Clone,
@@ -982,7 +897,7 @@ impl<T> Tensor<T> {
         // However there are cases of custom strides where copies could be
         // avoided. See https://pytorch.org/docs/stable/generated/torch.Tensor.view.html.
         self.make_contiguous();
-        self.layout = DynLayout::new(shape);
+        self.layout = DynLayout::from_shape(shape);
     }

     /// Like [Tensor::reshape] but consumes self.
@@ -1154,6 +1069,7 @@ fn fast_for_each_element<T, F: FnMut(&T)>(mut x: TensorView<T>, mut f: F) {
 mod tests {
     use std::ops::IndexMut;

+    use crate::iterators::Offsets;
     use crate::rng::XorShiftRng;
     use crate::tensor;
     use crate::{
@@ -1319,8 +1235,11 @@ mod tests {

         // Offsets should be relative to the slices returned by `data` and
         // `data_mut`.
-        assert_eq!(x.offsets().collect::<Vec<usize>>(), &[0, 1, 2, 3, 4, 5]);
-        assert_eq!(x.layout().offset([0, 0]), 0);
+        assert_eq!(
+            Offsets::new(&x).collect::<Vec<usize>>(),
+            &[0, 1, 2, 3, 4, 5]
+        );
+        assert_eq!(x.layout().offset(&[0, 0]), 0);
     }
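Note: the `clone_with_shape` → `to_shape` rename keeps the semantics (copy into a new tensor with the requested shape):

```rust
let x = tensor!([1, 2, 3, 4]);
let y = x.to_shape(&[2, 2]);
assert_eq!(y.shape(), &[2, 2]);
assert_eq!(y.data(), x.data());
```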
@@ -1559,7 +1478,7 @@ mod tests {
     #[test]
     fn test_partial_eq() {
         let x = tensor!([1, 2, 3, 4, 5]);
         let y = x.clone();
-        let z = x.clone_with_shape(&[1, 5]);
+        let z = x.to_shape(&[1, 5]);

         // Int tensors are equal if they have the same shape and elements.
         assert_eq!(&x, &y);
@@ -1751,47 +1670,33 @@ mod tests {
     }

     #[test]
-    fn test_clone_with_shape() {
+    fn test_to_shape() {
         let mut rng = XorShiftRng::new(1234);
         let x = Tensor::rand(&[10, 5, 3, 7], &mut rng);
-        let y = x.clone_with_shape(&[10, 5, 3 * 7]);
+        let y = x.to_shape(&[10, 5, 3 * 7]);
         assert_eq!(y.shape(), &[10, 5, 3 * 7]);
         assert_eq!(y.data(), x.data());
     }

     #[test]
-    fn test_nd_slice() {
+    fn test_nd_view() {
         let mut rng = XorShiftRng::new(1234);
         let x = Tensor::rand(&[10, 5, 3, 7], &mut rng);
-        let x_view = x.view().nd_slice([5, 3]);
-
-        for a in 0..x.size(2) {
-            for b in 0..x.size(3) {
-                assert_eq!(x[[5, 3, a, b]], x_view[[a, b]]);
-            }
-        }
+        let x_view = x.nd_view::<4>();
+        assert_eq!(x_view.shape(), x.shape());
+        assert_eq!(x_view.strides(), x.strides());
+        assert_eq!(x_view.data(), x.data());
     }

     #[test]
-    fn test_nd_slice_mut() {
+    fn test_nd_view_mut() {
         let mut rng = XorShiftRng::new(1234);
         let mut x = Tensor::rand(&[10, 5, 3, 7], &mut rng);
-
-        let [_, _, a_size, b_size]: [usize; 4] = x.shape().try_into().unwrap();
-        let mut x_view = x.nd_slice_mut([5, 3]);
-
-        for a in 0..a_size {
-            for b in 0..b_size {
-                x_view[[a, b]] = (a + b) as f32;
-            }
-        }
-
-        for a in 0..x.size(2) {
-            for b in 0..x.size(3) {
-                assert_eq!(x[[5, 3, a, b]], (a + b) as f32);
-            }
-        }
+        let layout = x.layout().clone();
+        let x_view = x.nd_view_mut::<4>();
+        assert_eq!(x_view.shape(), layout.shape());
+        assert_eq!(x_view.strides(), layout.strides());
     }

     #[test]
@@ -1955,7 +1860,7 @@ mod tests {

         let x_elts: Vec<_> = x.to_vec();

-        let x_offsets = x.offsets();
+        let x_offsets = Offsets::new(&x);
         let x_data = x.data_mut().unwrap();

         let x_elts_from_offset: Vec<_> = x_offsets.map(|off| x_data[off]).collect();
@@ -1965,14 +1870,14 @@ mod tests {
     #[test]
     fn test_offsets_nth() {
         let x = steps(&[3]);
-        let mut iter = x.offsets();
+        let mut iter = Offsets::new(&x);
         assert_eq!(iter.nth(0), Some(0));
         assert_eq!(iter.nth(0), Some(1));
         assert_eq!(iter.nth(0), Some(2));
         assert_eq!(iter.nth(0), None);

         let x = steps(&[10]);
-        let mut iter = x.offsets();
+        let mut iter = Offsets::new(&x);
         assert_eq!(iter.nth(1), Some(1));
         assert_eq!(iter.nth(5), Some(7));
         assert_eq!(iter.nth(1), Some(9));
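Note: with the test-only `Tensor::offsets` helper gone, tests build the iterator directly from any `Layout` implementor, which `TensorBase` itself now is:

```rust
let x = Tensor::from_data(&[2, 3], vec![0; 6]);
let offsets: Vec<usize> = Offsets::new(&x).collect();
assert_eq!(offsets, &[0, 1, 2, 3, 4, 5]);
```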
@@ -2459,28 +2364,6 @@ mod tests {
         }
     }

-    // These tests assume the correctness of `slice_iter`, given the tests
-    // above, and check for consistency between the results of `slice_offsets`
-    // and `slice_iter`.
-    #[test]
-    fn test_slice_offsets() {
-        let x = steps(&[5, 5]);
-
-        // Range that removes the start and end of each dimension.
-        let range = &[
-            SliceItem::range(1, Some(4), 1),
-            SliceItem::range(1, Some(4), 1),
-        ];
-        let expected: Vec<_> = x.slice_iter(range).copied().collect();
-        let x_data = x.data().unwrap();
-        let result: Vec<_> = x
-            .slice_offsets(range)
-            .map(|offset| x_data[offset])
-            .collect();
-
-        assert_eq!(&result, &expected);
-    }
-
     #[test]
     fn test_squeezed() {
         let mut rng = XorShiftRng::new(1234);
diff --git a/src/ops/conv.rs b/src/ops/conv.rs
index cb0c9edc..1a2ba0ae 100644
--- a/src/ops/conv.rs
+++ b/src/ops/conv.rs
@@ -646,7 +646,7 @@ pub fn conv_transpose(
         );

         col2im(
-            &mut output.nd_slice_mut([n]),
+            &mut output.nd_view_mut::<4>().slice_mut([n]),
             &col2im_mat
                 .nd_view::<2>()
                 .reshaped([in_h, in_w, out_c, k_h, k_w]),
diff --git a/src/ops/layout.rs b/src/ops/layout.rs
index c64ecb1f..0e2b7c9c 100644
--- a/src/ops/layout.rs
+++ b/src/ops/layout.rs
@@ -607,7 +607,7 @@ mod tests {
         // Reshape with an unspecified (-1) dim and nonzero-length input
         let input = Tensor::from_data(&[2, 2], vec![-0.5, 0.5, 3.0, -5.5]);
         let shape = ndtensor!([1, -1, 2]);
-        let expected = input.clone_with_shape(&[1, 2, 2]);
+        let expected = input.to_shape(&[1, 2, 2]);
         let result = reshape(input.view(), &shape.view(), false /* allow_zero */).unwrap();
         expect_equal(&result, &expected)?;

@@ -620,7 +620,7 @@ mod tests {
             false, /* allow_zero */
         )
         .unwrap();
-        let expected = zero_sized_input.clone_with_shape(&[100, 0]);
+        let expected = zero_sized_input.to_shape(&[100, 0]);
         expect_equal(&result, &expected)?;

         Ok(())
@@ -632,14 +632,14 @@ mod tests {
         // size should be copied.
         let input = Tensor::from_data(&[1, 1, 4], vec![-0.5, 0.5, 3.0, -5.5]);
         let shape = ndtensor!([-1, 0]);
-        let expected = input.clone_with_shape(&[4, 1]);
+        let expected = input.to_shape(&[4, 1]);
         let result = reshape(input.view(), &shape.view(), false /* allow_zero */).unwrap();
         expect_equal(&result, &expected)?;

         // Case where copied input dim is also zero.
         let input = Tensor::<f32>::from_data(&[0], vec![]);
         let shape = ndtensor!([0]);
-        let expected = input.clone_with_shape(&[0]);
+        let expected = input.to_shape(&[0]);
         let result = reshape(input.view(), &shape.view(), false /* allow_zero */).unwrap();
         expect_equal(&result, &expected)?;

@@ -658,7 +658,7 @@ mod tests {
         let input = Tensor::<f32>::from_data(&[0, 0, 10], vec![]);
         let shape = ndtensor!([10, 0, 0]);
         let result = reshape(input.view(), &shape.view(), true /* allow_zero */).unwrap();
-        let expected = input.clone_with_shape(&[10, 0, 0]);
+        let expected = input.to_shape(&[10, 0, 0]);
         expect_equal(&result, &expected)?;

         Ok(())
@@ -698,7 +698,7 @@ mod tests {
     fn test_reshape_in_place() {
         let mut input = Tensor::from_data(&[2, 2], vec![-0.5, 0.5, 3.0, -5.5]);
         let shape = ndtensor!([4]);
-        let expected = input.clone_with_shape(&[4]);
+        let expected = input.to_shape(&[4]);
         reshape_in_place(&mut input, &shape.view(), false /* allow_zero */).unwrap();
         assert_eq!(&input, &expected);
     }
@@ -707,7 +707,7 @@ mod tests {
     fn test_reshape_op() -> Result<(), Box<dyn Error>> {
         let input = Tensor::from_data(&[2, 2], vec![-0.5, 0.5, 3.0, -5.5]);
         let shape = Tensor::from_data(&[1], vec![4]);
-        let expected = input.clone_with_shape(&[4]);
+        let expected = input.to_shape(&[4]);

         let op = Reshape { allow_zero: false };
         let result = op
diff --git a/src/ops/mod.rs b/src/ops/mod.rs
index 6f9234ad..573d2cc6 100644
--- a/src/ops/mod.rs
+++ b/src/ops/mod.rs
@@ -146,13 +146,17 @@ impl<'a> Input<'a> {
 }

 impl<'a> Layout for Input<'a> {
-    type Index<'b> = <DynLayout as Layout>::Index<'b> where Self: 'b;
+    type Index<'b> = <DynLayout as Layout>::Index<'b>;
     type Indices = <DynLayout as Layout>::Indices;

     fn ndim(&self) -> usize {
         self.layout().ndim()
     }

+    fn try_offset(&self, index: Self::Index<'_>) -> Option<usize> {
+        self.layout().try_offset(index)
+    }
+
     fn len(&self) -> usize {
         self.layout().len()
     }
@@ -312,6 +316,10 @@ impl Layout for Output {
         self.layout().ndim()
     }

+    fn try_offset(&self, index: Self::Index<'_>) -> Option<usize> {
+        self.layout().try_offset(index)
+    }
+
     fn len(&self) -> usize {
         self.layout().len()
     }
diff --git a/src/ops/norm.rs b/src/ops/norm.rs
index 8a39b52b..65a7e6a0 100644
--- a/src/ops/norm.rs
+++ b/src/ops/norm.rs
@@ -20,6 +20,8 @@ pub fn batch_norm_in_place(
     epsilon: f32,
 ) -> Result<(), OpError> {
     let [batch, chans, in_h, in_w] = check_dims!(input, 4, "NCHW");
+    let mut input = input.nd_view_mut::<4>();
+
     for n in 0..batch {
         for c in 0..chans {
             let chan_mean = mean[[c]];
@@ -27,7 +29,7 @@ pub fn batch_norm_in_place(
             let chan_scale = scale[[c]];
             let chan_bias = bias[[c]];

-            let mut out_view = input.nd_slice_mut([n, c]);
+            let mut out_view = input.slice_mut([n, c]);
             let mut out_view = out_view.unchecked_mut();

             // The batch norm formula, from the ONNX spec, is:
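Note: the pattern in `batch_norm_in_place` above (and `average_pool` below) is to hoist one static-rank view out of the loop, then take rank-reduced slices per iteration. A sketch under assumed dimensions (the variable names are hypothetical, and the explicit type on `plane` is an assumption about `slice_mut`'s output, which is inferred from context at the real call sites):

```rust
let (batch, chans) = (2, 3);
let mut out = Tensor::<f32>::zeros(&[batch, chans, 4, 4]);
let mut out4 = out.nd_view_mut::<4>();
for n in 0..batch {
    for c in 0..chans {
        // Rank-2 [h, w] plane for this (batch, channel) pair.
        let mut plane: NdTensorViewMut<f32, 2> = out4.slice_mut([n, c]);
        plane[[0, 0]] = (n + c) as f32;
    }
}
```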
diff --git a/src/ops/pooling.rs b/src/ops/pooling.rs
index 554d0018..249770af 100644
--- a/src/ops/pooling.rs
+++ b/src/ops/pooling.rs
@@ -2,7 +2,7 @@ use std::iter::zip;

 use rayon::prelude::*;
 use rten_tensor::prelude::*;
-use rten_tensor::{NdTensorView, NdTensorViewMut, Tensor, TensorView};
+use rten_tensor::{NdTensor, NdTensorView, NdTensorViewMut, Tensor, TensorView};

 use crate::check_dims;
 use crate::gemm::div_ceil;
@@ -104,14 +104,14 @@ pub fn average_pool(
     let [kernel_h, kernel_w] = kernel_size;
     let [stride_h, stride_w] = strides;

-    let mut output = Tensor::zeros(&[batch, in_c, out_h, out_w]);
-    let input = input.view();
+    let mut output = NdTensor::zeros([batch, in_c, out_h, out_w]);
+    let input = input.nd_view::<4>();

     for n in 0..batch {
         for chan in 0..in_c {
-            let mut out_view = output.nd_slice_mut([n, chan]);
+            let mut out_view = output.slice_mut([n, chan]);
             let mut out_view = out_view.unchecked_mut();
-            let in_view = input.nd_slice([n, chan]).unchecked();
+            let in_view = input.slice([n, chan]).unchecked();

             for out_y in 0..out_h {
                 for out_x in 0..out_w {
@@ -140,7 +140,7 @@ pub fn average_pool(
         }
     }

-    Ok(output)
+    Ok(output.into_dyn())
 }

 #[derive(Debug)]
diff --git a/src/ops/reduce.rs b/src/ops/reduce.rs
index a490c6d4..df34d9ec 100644
--- a/src/ops/reduce.rs
+++ b/src/ops/reduce.rs
@@ -849,7 +849,7 @@ mod tests {
         expect_equal(&result, &expected)?;

         let result = reduce_l2(input.view(), Some(&[2]), true /* keep_dims */).unwrap();
-        let expected = expected.clone_with_shape(&[3, 2, 1]);
+        let expected = expected.to_shape(&[3, 2, 1]);
         expect_equal(&result, &expected)?;

         Ok(())