Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make various iterators generic over the tensor layout #40

Merged
merged 11 commits into from
Jan 14, 2024
172 changes: 106 additions & 66 deletions rten-tensor/src/iterators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,56 @@ use std::iter::{repeat, zip, Cycle, FusedIterator, StepBy, Take};
use std::ops::{Add, Range};
use std::slice;

use super::layout::DynLayout;
use super::range::{SliceItem, SliceRange};
use crate::{
to_slice_items, DynIndices, Layout, NdTensorView, NdTensorViewMut, TensorView, TensorViewMut,
};

/// A lightweight pairing of a borrowed data slice with a borrowed layout.
///
/// Unlike [TensorView], which has its own layout, this type only borrows the
/// layout.
///
/// Lifetimes: `'d` is the lifetime of the data, `'l` the lifetime of the
/// layout.
pub(crate) struct ViewRef<'d, 'l, T, L>
where
    L: Layout,
{
    /// Borrowed tensor data.
    data: &'d [T],
    /// Borrowed tensor layout.
    layout: &'l L,
}

impl<'d, 'l, T, L: Layout> ViewRef<'d, 'l, T, L> {
    /// Bundle `data` and `layout` into a `ViewRef`.
    pub(crate) fn new(data: &'d [T], layout: &'l L) -> Self {
        Self { data, layout }
    }

    /// Return the borrowed data slice when the layout reports itself as
    /// contiguous, or `None` otherwise.
    fn contiguous_data(&self) -> Option<&'d [T]> {
        if self.layout.is_contiguous() {
            Some(self.data)
        } else {
            None
        }
    }

    /// Return the shape reported by the borrowed layout.
    fn shape(&self) -> L::Index<'_> {
        self.layout.shape()
    }
}

// Written by hand rather than derived: `#[derive(Clone)]` would add
// `T: Clone` and `L: Clone` bounds, which are not needed because only the two
// references are copied.
impl<'d, 'l, T, L: Layout> Clone for ViewRef<'d, 'l, T, L> {
    /// Produce another `ViewRef` pointing at the same data and layout.
    fn clone(&self) -> Self {
        // Both fields are shared references, so they can be copied out of
        // `*self` directly.
        let ViewRef { data, layout } = *self;
        ViewRef { data, layout }
    }
}

/// A pairing of a mutably borrowed data slice with a borrowed layout.
///
/// Unlike [TensorViewMut], which has its own layout, this type only borrows
/// the layout.
pub(crate) struct MutViewRef<'d, 'l, T, L>
where
    L: Layout,
{
    /// Mutably borrowed tensor data.
    data: &'d mut [T],
    /// Borrowed tensor layout.
    layout: &'l L,
}

impl<'d, 'l, T, L: Layout> MutViewRef<'d, 'l, T, L> {
    /// Bundle a mutable `data` slice and a `layout` into a `MutViewRef`.
    pub(crate) fn new(data: &'d mut [T], layout: &'l L) -> Self {
        Self { data, layout }
    }
}

/// IterPos tracks the position within a single dimension of an IndexingIter.
#[derive(Copy, Clone, Debug)]
struct IterPos {
Expand Down Expand Up @@ -62,9 +106,10 @@ struct IndexingIterBase {

impl IndexingIterBase {
/// Create an iterator over element offsets in `tensor`.
fn new(layout: &DynLayout) -> IndexingIterBase {
fn new<L: Layout>(layout: &L) -> IndexingIterBase {
let dims = layout
.shape()
.as_ref()
.iter()
.enumerate()
.map(|(dim, &len)| IterPos::new(len, layout.stride(dim) as isize))
Expand All @@ -79,11 +124,13 @@ impl IndexingIterBase {

/// Create an iterator over offsets of elements in `tensor`, as if it had
/// a given `shape`. This will repeat offsets as necessary.
fn broadcast(layout: &DynLayout, shape: &[usize]) -> IndexingIterBase {
fn broadcast<L: Layout>(layout: &L, shape: &[usize]) -> IndexingIterBase {
// nb. We require that the broadcast shape has a length >= the actual
// shape.
let added_dims = shape.len() - layout.shape().len();
let padded_tensor_shape = repeat(&0).take(added_dims).chain(layout.shape().iter());
let added_dims = shape.len() - layout.ndim();
let layout_shape = layout.shape();
let layout_shape = layout_shape.as_ref();
let padded_tensor_shape = repeat(&0).take(added_dims).chain(layout_shape.iter());
let dims = zip(padded_tensor_shape, shape.iter())
.enumerate()
.map(|(dim, (&actual_len, &broadcast_len))| {
Expand All @@ -108,7 +155,7 @@ impl IndexingIterBase {
}

/// Create an iterator over offsets of a subset of elements in `tensor`.
fn slice(layout: &DynLayout, range: &[SliceItem]) -> IndexingIterBase {
fn slice<L: Layout>(layout: &L, range: &[SliceItem]) -> IndexingIterBase {
assert!(
range.len() == layout.ndim(),
"slice dimensions {} do not match tensor dimensions {}",
Expand Down Expand Up @@ -231,8 +278,8 @@ enum IterKind<'a, T> {
}

impl<'a, T> Iter<'a, T> {
pub(super) fn new(view: &TensorView<'a, T>) -> Iter<'a, T> {
if let Some(data) = view.data() {
pub(super) fn new<L: Layout>(view: ViewRef<'a, '_, T, L>) -> Iter<'a, T> {
if let Some(data) = view.contiguous_data() {
Iter {
iter: IterKind::Direct(data.iter()),
}
Expand All @@ -243,14 +290,13 @@ impl<'a, T> Iter<'a, T> {
}
}

pub(super) fn slice(view: &TensorView<'a, T>, range: &[SliceItem]) -> Iter<'a, T> {
pub(super) fn slice<L: Layout>(
view: ViewRef<'a, '_, T, L>,
range: &[SliceItem],
) -> Iter<'a, T> {
let iter = IndexingIter {
base: IndexingIterBase::slice(view.layout(), range),

// Safety: The `Iterator::next` impl must only yield offsets from
// this slice that belong to the tensor slice. This is true of
// the offsets yielded by `IndexingIterBase`.
data: unsafe { view.data_unchecked() },
base: IndexingIterBase::slice(view.layout, range),
data: view.data,
};
Iter {
iter: IterKind::Indexing(iter),
Expand Down Expand Up @@ -300,25 +346,17 @@ struct IndexingIter<'a, T> {
}

impl<'a, T> IndexingIter<'a, T> {
fn new(view: &TensorView<'a, T>) -> IndexingIter<'a, T> {
fn new<L: Layout>(view: ViewRef<'a, '_, T, L>) -> IndexingIter<'a, T> {
IndexingIter {
base: IndexingIterBase::new(view.layout()),

// Safety: The `Iterator::next` impl must only yield offsets from
// this slice that belong to the tensor view. This is true of
// the offsets yielded by `IndexingIterBase`.
data: unsafe { view.data_unchecked() },
base: IndexingIterBase::new(view.layout),
data: view.data,
}
}

fn broadcast(view: &TensorView<'a, T>, shape: &[usize]) -> IndexingIter<'a, T> {
fn broadcast<L: Layout>(view: ViewRef<'a, '_, T, L>, shape: &[usize]) -> IndexingIter<'a, T> {
IndexingIter {
base: IndexingIterBase::broadcast(view.layout(), shape),

// Safety: The `Iterator::next` impl must only yield offsets from
// this slice that belong to the broadcasted tensor view. This is
// true of the offsets yielded by `IndexingIterBase`.
data: unsafe { view.data_unchecked() },
base: IndexingIterBase::broadcast(view.layout, shape),
data: view.data,
}
}
}
Expand Down Expand Up @@ -360,14 +398,14 @@ enum IterMutKind<'a, T> {
}

impl<'a, T> IterMut<'a, T> {
pub(super) fn new(data: &'a mut [T], layout: &DynLayout) -> IterMut<'a, T> {
if layout.is_contiguous() {
pub(super) fn new<L: Layout>(view: MutViewRef<'a, '_, T, L>) -> IterMut<'a, T> {
if view.layout.is_contiguous() {
IterMut {
iter: IterMutKind::Direct(data.iter_mut()),
iter: IterMutKind::Direct(view.data.iter_mut()),
}
} else {
IterMut {
iter: IterMutKind::Indexing(IndexingIterMut::new(data, layout)),
iter: IterMutKind::Indexing(IndexingIterMut::new(view)),
}
}
}
Expand Down Expand Up @@ -415,15 +453,15 @@ struct IndexingIterMut<'a, T> {
}

impl<'a, T> IndexingIterMut<'a, T> {
fn new(data: &'a mut [T], layout: &DynLayout) -> IndexingIterMut<'a, T> {
fn new<L: Layout>(view: MutViewRef<'a, '_, T, L>) -> IndexingIterMut<'a, T> {
// See notes in `Layout` about internal overlap.
assert!(
!layout.is_broadcast(),
!view.layout.is_broadcast(),
"Cannot mutably iterate over broadcasting view"
);
IndexingIterMut {
base: IndexingIterBase::new(layout),
data,
base: IndexingIterBase::new(view.layout),
data: view.data,
}
}
}
Expand Down Expand Up @@ -467,19 +505,19 @@ pub struct Offsets {
}

impl Offsets {
pub fn new(layout: &DynLayout) -> Offsets {
pub fn new<L: Layout>(layout: &L) -> Offsets {
Offsets {
base: IndexingIterBase::new(layout),
}
}

pub fn broadcast(layout: &DynLayout, shape: &[usize]) -> Offsets {
pub fn broadcast<L: Layout>(layout: &L, shape: &[usize]) -> Offsets {
Offsets {
base: IndexingIterBase::broadcast(layout, shape),
}
}

pub fn slice(layout: &DynLayout, range: &[SliceItem]) -> Offsets {
pub fn slice<L: Layout>(layout: &L, range: &[SliceItem]) -> Offsets {
Offsets {
base: IndexingIterBase::slice(layout, range),
}
Expand Down Expand Up @@ -563,16 +601,20 @@ fn can_broadcast_by_cycling(from_shape: &[usize], to_shape: &[usize]) -> bool {
}

impl<'a, T> BroadcastIter<'a, T> {
pub fn new(view: &TensorView<'a, T>, to_shape: &[usize]) -> BroadcastIter<'a, T> {
pub(crate) fn new<L: Layout>(
view: ViewRef<'a, '_, T, L>,
to_shape: &[usize],
) -> BroadcastIter<'a, T> {
let tmp_view = view.clone();
let iter = match (
view.data(),
can_broadcast_by_cycling(view.shape(), to_shape),
view.contiguous_data(),
can_broadcast_by_cycling(view.shape().as_ref(), to_shape),
) {
(Some(data), true) => {
let iter_len = to_shape.iter().product();
BroadcastIterKind::Direct(data.iter().cycle().take(iter_len))
}
_ => BroadcastIterKind::Indexing(IndexingIter::broadcast(view, to_shape)),
_ => BroadcastIterKind::Indexing(IndexingIter::broadcast(tmp_view, to_shape)),
};
BroadcastIter { iter }
}
Expand Down Expand Up @@ -771,21 +813,21 @@ struct LaneRanges {
}

impl LaneRanges {
fn new<T>(tensor: &TensorView<T>, dim: usize) -> LaneRanges {
let slice_starts: Vec<SliceItem> = (0..tensor.ndim())
fn new<L: Layout>(layout: &L, dim: usize) -> LaneRanges {
let slice_starts: Vec<SliceItem> = (0..layout.ndim())
.map(|i| {
if i == dim {
(0..1).into()
} else {
(0..(tensor.shape()[i] as isize)).into()
(0..(layout.size(i) as isize)).into()
}
})
.collect();
let offsets = tensor.slice_offsets(&slice_starts);
let offsets = Offsets::slice(layout, &slice_starts);
LaneRanges {
offsets,
dim_size: tensor.size(dim),
dim_stride: tensor.stride(dim),
dim_size: layout.size(dim),
dim_stride: layout.stride(dim),
}
}
}
Expand Down Expand Up @@ -835,12 +877,10 @@ impl<'a, T> ExactSizeIterator for Lane<'a, T> {}
impl<'a, T> Lanes<'a, T> {
/// Create an iterator which yields all possible slices over the `dim`
/// dimension of `tensor`.
pub(crate) fn new(tensor: TensorView<'a, T>, dim: usize) -> Lanes<'a, T> {
pub(crate) fn new<L: Layout>(view: ViewRef<'a, '_, T, L>, dim: usize) -> Lanes<'a, T> {
Lanes {
// Safety: `Lane`s yielded by this iterator will only yield elements
// that belong to this lane.
data: unsafe { tensor.data_unchecked() },
ranges: LaneRanges::new(&tensor, dim),
data: view.data,
ranges: LaneRanges::new(view.layout, dim),
}
}
}
Expand All @@ -865,22 +905,22 @@ impl<'a, T> Iterator for Lanes<'a, T> {
/// in implementing this for an iterator that returns mutable references, but
/// it has a similar interface.
pub struct LanesMut<'a, T> {
tensor: TensorViewMut<'a, T>,
data: &'a mut [T],
ranges: LaneRanges,
}

impl<'a, T> LanesMut<'a, T> {
/// Create an iterator which yields all possible slices over the `dim`
/// dimension of `tensor`.
pub(crate) fn new(tensor: TensorViewMut<'a, T>, dim: usize) -> LanesMut<'a, T> {
/// dimension of `view`.
pub(crate) fn new<L: Layout>(view: MutViewRef<'a, '_, T, L>, dim: usize) -> LanesMut<'a, T> {
// See notes in `Layout` about internal overlap.
assert!(
!tensor.layout().is_broadcast(),
!view.layout.is_broadcast(),
"Cannot mutably iterate over broadcasting view"
);
LanesMut {
ranges: LaneRanges::new(&tensor.view(), dim),
tensor,
ranges: LaneRanges::new(view.layout, dim),
data: view.data,
}
}
}
Expand Down Expand Up @@ -913,7 +953,7 @@ impl<'a, T> Iterator for LanesMut<'a, T> {
// Safety: This is a non-broadcasting view, so each `LaneMut`
// yielded by this iterator will yield a distinct set of elements.
let slice = unsafe {
let slice = &mut self.tensor.data_mut_unchecked()[range];
let slice = &mut self.data[range];
std::mem::transmute::<&mut [T], &'a mut [T]>(slice)
};

Expand Down Expand Up @@ -1011,14 +1051,14 @@ mod tests {
#[test]
fn test_lanes_empty() {
let x = Tensor::<i32>::zeros(&[5, 0]);
assert!(Lanes::new(x.view(), 0).next().is_none());
assert!(Lanes::new(x.view(), 1).next().is_none());
assert!(Lanes::new(x.view().view_ref(), 0).next().is_none());
assert!(Lanes::new(x.view().view_ref(), 1).next().is_none());
}

#[test]
fn test_lanes_mut_empty() {
let mut x = Tensor::<i32>::zeros(&[5, 0]);
assert!(LanesMut::new(x.view_mut(), 0).next().is_none());
assert!(LanesMut::new(x.view_mut(), 1).next().is_none());
assert!(LanesMut::new(x.mut_view_ref(), 0).next().is_none());
assert!(LanesMut::new(x.mut_view_ref(), 1).next().is_none());
}
}
Loading