Skip to content

Commit

Permalink
Merge pull request #568 from robertknight/cow-type-alias
Browse files Browse the repository at this point in the history
Add `CowTensor`, `CowNdTensor` type aliases
  • Loading branch information
robertknight authored Feb 2, 2025
2 parents 36d7986 + 5082a7c commit 9d4a9cf
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 16 deletions.
38 changes: 22 additions & 16 deletions rten-tensor/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,30 @@
//! rten_tensor provides multi-dimensional arrays, commonly referred to as
//! _tensors_ in a machine learning context.
//!
//! Each tensor is a combination of data and a layout. The data can be owned,
//! borrowed or mutably borrowed. This is analagous to `Vec<T>`, `&[T]` and
//! `&mut [T]` for 1D arrays. The layout determines the number of dimensions
//! (the _rank_), the size of each dimension, and the strides (gap between
//! successive indices along a given dimension).
//! # Storage and layout
//!
//! # Key types and traits
//! A tensor is a combination of data storage and a layout. The data storage
//! determines the element type and how the data is owned. A tensor can be:
//!
//! - Owned (like `Vec<T>`)
//! - Borrowed (like `&[T]` or `&mut [T]`)
//! - Maybe-owned (like `Cow[T]`)
//!
//! The layout determines the number of dimensions (the _rank_) and size of each
//! (the _shape_) and how indices map to offsets in the storage. The dimension
//! count can be static (fixed at compile time) or dynamic (variable at
//! runtime).
//!
//! # Tensor types and traits
//!
//! The base type for all tensors is [TensorBase]. This is not normally used
//! directly but instead via a type alias, depending on whether the number of
//! dimensions (the _rank_) of the tensor is known at compile time or only
//! at runtime, as well as whether the tensor owns, borrows or mutably borrows
//! its data.
//! directly but instead via a type alias which specifies the data ownership
//! and layout:
//!
//! | Rank | Owned (like `Vec<T>`) | Borrowed (like `&[T]`) | Mutably borrowed |
//! | ---- | --------------------- | ---------------------- | ---------------- |
//! | Static | [NdTensor] | [NdTensorView] | [NdTensorViewMut] |
//! | Dynamic | [Tensor] | [TensorView] | [TensorViewMut] |
//! | Rank | Owned | Borrowed | Mutably borrowed | Owned or borrowed |
//! | ---- | ----- | -------- | ---------------- | ----------------- |
//! | Static | [NdTensor] | [NdTensorView] | [NdTensorViewMut] | [CowNdTensor] |
//! | Dynamic | [Tensor] | [TensorView] | [TensorViewMut] | [CowTensor] |
//!
//! All tensors implement the [Layout] trait, which provide methods to query
//! the shape, dimension count and strides of the tensor. Tensor views provide
Expand Down Expand Up @@ -125,8 +131,8 @@ pub use layout::{
pub use slice_range::{to_slice_items, DynSliceItems, IntoSliceItems, SliceItem, SliceRange};

pub use tensor::{
AsView, Matrix, MatrixMut, NdTensor, NdTensorView, NdTensorViewMut, Scalar, Tensor, TensorBase,
TensorView, TensorViewMut, WeaklyCheckedView,
AsView, CowNdTensor, CowTensor, Matrix, MatrixMut, NdTensor, NdTensorView, NdTensorViewMut,
Scalar, Tensor, TensorBase, TensorView, TensorViewMut, WeaklyCheckedView,
};

pub use storage::{CowData, IntoStorage, Storage, StorageMut, ViewData, ViewMutData};
Expand Down
16 changes: 16 additions & 0 deletions rten-tensor/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2086,6 +2086,14 @@ pub type NdTensor<T, const N: usize> = TensorBase<Vec<T>, NdLayout<N>>;
/// Mutable view of a tensor with N dimensions.
pub type NdTensorViewMut<'a, T, const N: usize> = TensorBase<ViewMutData<'a, T>, NdLayout<N>>;

/// Owned or borrowed tensor with N dimensions.
///
/// `CowNdTensor`s can be created using [`as_cow`](TensorBase::as_cow) (to
/// borrow) or [`into_cow`](TensorBase::into_cow).
///
/// The name comes from [`std::borrow::Cow`].
pub type CowNdTensor<'a, T, const N: usize> = TensorBase<CowData<'a, T>, NdLayout<N>>;

/// View of a 2D tensor.
pub type Matrix<'a, T = f32> = NdTensorView<'a, T, 2>;

Expand All @@ -2101,6 +2109,14 @@ pub type TensorView<'a, T = f32> = TensorBase<ViewData<'a, T>, DynLayout>;
/// Mutable view of a tensor with a dynamic dimension count.
pub type TensorViewMut<'a, T = f32> = TensorBase<ViewMutData<'a, T>, DynLayout>;

/// Owned or borrowed tensor with a dynamic dimension count.
///
/// `CowTensor`s can be created using [`as_cow`](TensorBase::as_cow) (to
/// borrow) or [`into_cow`](TensorBase::into_cow).
///
/// The name comes from [`std::borrow::Cow`].
pub type CowTensor<'a, T> = TensorBase<CowData<'a, T>, DynLayout>;

impl<T, S: Storage<Elem = T>, L: MutLayout, I: AsIndex<L>> Index<I> for TensorBase<S, L> {
type Output = T;

Expand Down

0 comments on commit 9d4a9cf

Please sign in to comment.