Skip to content

Commit

Permalink
fix: swap upcast <-> downcast to match actual meaning
Browse files Browse the repository at this point in the history
closes #191
  • Loading branch information
decahedron1 committed Apr 25, 2024
1 parent 5227570 commit cedeb55
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 50 deletions.
10 changes: 5 additions & 5 deletions docs/migrating/v2.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ You can still extract tensors, maps, or sequence values normally from a `DynValu
let generated_tokens: ArrayViewD<f32> = outputs["output1"].try_extract_tensor()?;
```

`DynValue` can be `upcast()`ed to the more specialized types, like `DynMap` or `Tensor<T>`:
`DynValue` can be `downcast()`ed to the more specialized types, like `DynMap` or `Tensor<T>`:
```rust
let tensor: Tensor<f32> = value.upcast()?;
let map: DynMap = value.upcast()?;
let tensor: Tensor<f32> = value.downcast()?;
let map: DynMap = value.downcast()?;
```

Similarly, a strongly-typed value like `Tensor<T>` can be downcast back into a `DynValue` or `DynTensor`.
Similarly, a strongly-typed value like `Tensor<T>` can be upcast back into a `DynValue` or `DynTensor`.
```rust
let dyn_tensor: DynTensor = tensor.downcast();
let dyn_tensor: DynTensor = tensor.upcast();
let dyn_value: DynValue = tensor.into_dyn();
```

Expand Down
6 changes: 3 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ pub use self::session::{
pub use self::tensor::ArrayExtensions;
pub use self::tensor::{IntoTensorElementType, TensorElementType};
pub use self::value::{
DynMap, DynMapRef, DynMapRefMut, DynMapValueType, DynSequence, DynSequenceRef, DynSequenceRefMut, DynSequenceValueType, DynTensor, DynTensorRef,
DynTensorRefMut, DynTensorValueType, DynValue, DynValueTypeMarker, Map, MapRef, MapRefMut, MapValueType, MapValueTypeMarker, Sequence, SequenceRef,
SequenceRefMut, SequenceValueType, SequenceValueTypeMarker, Tensor, TensorRef, TensorRefMut, TensorValueTypeMarker, UpcastableTarget, Value, ValueRef,
DowncastableTarget, DynMap, DynMapRef, DynMapRefMut, DynMapValueType, DynSequence, DynSequenceRef, DynSequenceRefMut, DynSequenceValueType, DynTensor,
DynTensorRef, DynTensorRefMut, DynTensorValueType, DynValue, DynValueTypeMarker, Map, MapRef, MapRefMut, MapValueType, MapValueTypeMarker, Sequence,
SequenceRef, SequenceRefMut, SequenceValueType, SequenceValueTypeMarker, Tensor, TensorRef, TensorRefMut, TensorValueTypeMarker, Value, ValueRef,
ValueRefMut, ValueType, ValueTypeMarker
};

Expand Down
6 changes: 3 additions & 3 deletions src/value/impl_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,13 @@ impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: IntoTensorElementT

/// Converts from a strongly-typed [`Map<K, V>`] to a type-erased [`DynMap`].
#[inline]
pub fn downcast(self) -> DynMap {
pub fn upcast(self) -> DynMap {
unsafe { std::mem::transmute(self) }
}

/// Converts from a strongly-typed [`Map<K, V>`] to a reference to a type-erased [`DynMap`].
#[inline]
pub fn downcast_ref(&self) -> DynMapRef {
pub fn upcast_ref(&self) -> DynMapRef {
DynMapRef::new(unsafe {
Value::from_ptr_nodrop(
NonNull::new_unchecked(self.ptr()),
Expand All @@ -156,7 +156,7 @@ impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: IntoTensorElementT

/// Converts from a strongly-typed [`Map<K, V>`] to a mutable reference to a type-erased [`DynMap`].
#[inline]
pub fn downcast_mut(&mut self) -> DynMapRefMut {
pub fn upcast_mut(&mut self) -> DynMapRefMut {
DynMapRefMut::new(unsafe {
Value::from_ptr_nodrop(
NonNull::new_unchecked(self.ptr()),
Expand Down
22 changes: 11 additions & 11 deletions src/value/impl_sequence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{
ptr::{self, NonNull}
};

use super::{UpcastableTarget, ValueInner, ValueTypeMarker};
use super::{DowncastableTarget, ValueInner, ValueTypeMarker};
use crate::{memory::Allocator, ortsys, Error, Result, Value, ValueRef, ValueRefMut, ValueType};

pub trait SequenceValueTypeMarker: ValueTypeMarker {}
Expand All @@ -15,9 +15,9 @@ impl ValueTypeMarker for DynSequenceValueType {}
impl SequenceValueTypeMarker for DynSequenceValueType {}

#[derive(Debug)]
pub struct SequenceValueType<T: ValueTypeMarker + UpcastableTarget + Debug + ?Sized>(PhantomData<T>);
impl<T: ValueTypeMarker + UpcastableTarget + Debug + ?Sized> ValueTypeMarker for SequenceValueType<T> {}
impl<T: ValueTypeMarker + UpcastableTarget + Debug + ?Sized> SequenceValueTypeMarker for SequenceValueType<T> {}
pub struct SequenceValueType<T: ValueTypeMarker + DowncastableTarget + Debug + ?Sized>(PhantomData<T>);
impl<T: ValueTypeMarker + DowncastableTarget + Debug + ?Sized> ValueTypeMarker for SequenceValueType<T> {}
impl<T: ValueTypeMarker + DowncastableTarget + Debug + ?Sized> SequenceValueTypeMarker for SequenceValueType<T> {}

pub type DynSequence = Value<DynSequenceValueType>;
pub type Sequence<T> = Value<SequenceValueType<T>>;
Expand All @@ -28,7 +28,7 @@ pub type SequenceRef<'v, T> = ValueRef<'v, SequenceValueType<T>>;
pub type SequenceRefMut<'v, T> = ValueRefMut<'v, SequenceValueType<T>>;

impl<Type: SequenceValueTypeMarker + Sized> Value<Type> {
pub fn try_extract_sequence<'s, OtherType: ValueTypeMarker + UpcastableTarget + Debug + Sized>(
pub fn try_extract_sequence<'s, OtherType: ValueTypeMarker + DowncastableTarget + Debug + Sized>(
&'s self,
allocator: &Allocator
) -> Result<Vec<ValueRef<'s, OtherType>>> {
Expand All @@ -47,7 +47,7 @@ impl<Type: SequenceValueTypeMarker + Sized> Value<Type> {
lifetime: PhantomData
};
let value_type = value.dtype()?;
if !OtherType::can_upcast(&value.dtype()?) {
if !OtherType::can_downcast(&value.dtype()?) {
return Err(Error::InvalidSequenceElementType { actual: value_type });
}

Expand All @@ -60,7 +60,7 @@ impl<Type: SequenceValueTypeMarker + Sized> Value<Type> {
}
}

impl<T: ValueTypeMarker + UpcastableTarget + Debug + Sized + 'static> Value<SequenceValueType<T>> {
impl<T: ValueTypeMarker + DowncastableTarget + Debug + Sized + 'static> Value<SequenceValueType<T>> {
/// Creates a [`Sequence`] from an array of [`Value<T>`].
///
/// This `Value<T>` must be either a [`crate::Tensor`] or [`crate::Map`].
Expand Down Expand Up @@ -99,20 +99,20 @@ impl<T: ValueTypeMarker + UpcastableTarget + Debug + Sized + 'static> Value<Sequ
}
}

impl<T: ValueTypeMarker + UpcastableTarget + Debug + Sized> Value<SequenceValueType<T>> {
impl<T: ValueTypeMarker + DowncastableTarget + Debug + Sized> Value<SequenceValueType<T>> {
pub fn extract_sequence<'s>(&'s self, allocator: &Allocator) -> Vec<ValueRef<'s, T>> {
self.try_extract_sequence(allocator).expect("Failed to extract sequence")
}

/// Converts from a strongly-typed [`Sequence<T>`] to a type-erased [`DynSequence`].
#[inline]
pub fn downcast(self) -> DynSequence {
pub fn upcast(self) -> DynSequence {
unsafe { std::mem::transmute(self) }
}

/// Converts from a strongly-typed [`Sequence<T>`] to a reference to a type-erased [`DynTensor`].
#[inline]
pub fn downcast_ref(&self) -> DynSequenceRef {
pub fn upcast_ref(&self) -> DynSequenceRef {
DynSequenceRef::new(unsafe {
Value::from_ptr_nodrop(
NonNull::new_unchecked(self.ptr()),
Expand All @@ -123,7 +123,7 @@ impl<T: ValueTypeMarker + UpcastableTarget + Debug + Sized> Value<SequenceValueT

/// Converts from a strongly-typed [`Sequence<T>`] to a mutable reference to a type-erased [`DynTensor`].
#[inline]
pub fn downcast_mut(&mut self) -> DynSequenceRefMut {
pub fn upcast_mut(&mut self) -> DynSequenceRefMut {
DynSequenceRefMut::new(unsafe {
Value::from_ptr_nodrop(
NonNull::new_unchecked(self.ptr()),
Expand Down
8 changes: 4 additions & 4 deletions src/value/impl_tensor/create.rs
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ where
{
type Error = Error;
fn try_from(arr: &'i CowArray<'v, T, D>) -> Result<Self, Self::Error> {
Tensor::from_array(arr).map(|c| c.downcast())
Tensor::from_array(arr).map(|c| c.upcast())
}
}

Expand All @@ -536,7 +536,7 @@ where
impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom<ArrayView<'v, T, D>> for DynTensor {
type Error = Error;
fn try_from(arr: ArrayView<'v, T, D>) -> Result<Self, Self::Error> {
Tensor::from_array(arr).map(|c| c.downcast())
Tensor::from_array(arr).map(|c| c.upcast())
}
}

Expand Down Expand Up @@ -573,7 +573,7 @@ macro_rules! impl_try_from {
impl<T: IntoTensorElementType + Debug + Clone + 'static, I: ToDimensions> TryFrom<$t> for DynTensor {
type Error = Error;
fn try_from(value: $t) -> Result<Self, Self::Error> {
Tensor::from_array(value).map(|c| c.downcast())
Tensor::from_array(value).map(|c| c.upcast())
}
}
impl<T: IntoTensorElementType + Debug + Clone + 'static, I: ToDimensions> TryFrom<$t> for crate::DynValue {
Expand All @@ -597,7 +597,7 @@ macro_rules! impl_try_from {
impl<T: IntoTensorElementType + Debug + Clone + 'static, D: ndarray::Dimension + 'static> TryFrom<$t> for DynTensor {
type Error = Error;
fn try_from(value: $t) -> Result<Self, Self::Error> {
Tensor::from_array(value).map(|c| c.downcast())
Tensor::from_array(value).map(|c| c.upcast())
}
}
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
Expand Down
16 changes: 8 additions & 8 deletions src/value/impl_tensor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::{
ptr::NonNull
};

use super::{UpcastableTarget, Value, ValueInner, ValueTypeMarker};
use super::{DowncastableTarget, Value, ValueInner, ValueTypeMarker};
use crate::{ortsys, DynValue, Error, IntoTensorElementType, MemoryInfo, Result, ValueRef, ValueRefMut, ValueType};

pub trait TensorValueTypeMarker: ValueTypeMarker {}
Expand All @@ -31,8 +31,8 @@ pub type DynTensorRefMut<'v> = ValueRefMut<'v, DynTensorValueType>;
pub type TensorRef<'v, T> = ValueRef<'v, TensorValueType<T>>;
pub type TensorRefMut<'v, T> = ValueRefMut<'v, TensorValueType<T>>;

impl UpcastableTarget for DynTensorValueType {
fn can_upcast(dtype: &ValueType) -> bool {
impl DowncastableTarget for DynTensorValueType {
fn can_downcast(dtype: &ValueType) -> bool {
matches!(dtype, ValueType::Tensor { .. })
}
}
Expand Down Expand Up @@ -63,13 +63,13 @@ impl<Type: TensorValueTypeMarker + ?Sized> Value<Type> {
impl<T: IntoTensorElementType + Debug> Tensor<T> {
/// Converts from a strongly-typed [`Tensor<T>`] to a type-erased [`DynTensor`].
#[inline]
pub fn downcast(self) -> DynTensor {
pub fn upcast(self) -> DynTensor {
unsafe { std::mem::transmute(self) }
}

/// Converts from a strongly-typed [`Tensor<T>`] to a reference to a type-erased [`DynTensor`].
#[inline]
pub fn downcast_ref(&self) -> DynTensorRef {
pub fn upcast_ref(&self) -> DynTensorRef {
DynTensorRef::new(unsafe {
Value::from_ptr_nodrop(
NonNull::new_unchecked(self.ptr()),
Expand All @@ -80,7 +80,7 @@ impl<T: IntoTensorElementType + Debug> Tensor<T> {

/// Converts from a strongly-typed [`Tensor<T>`] to a mutable reference to a type-erased [`DynTensor`].
#[inline]
pub fn downcast_mut(&mut self) -> DynTensorRefMut {
pub fn upcast_mut(&mut self) -> DynTensorRefMut {
DynTensorRefMut::new(unsafe {
Value::from_ptr_nodrop(
NonNull::new_unchecked(self.ptr()),
Expand All @@ -90,8 +90,8 @@ impl<T: IntoTensorElementType + Debug> Tensor<T> {
}
}

impl<T: IntoTensorElementType + Debug> UpcastableTarget for TensorValueType<T> {
fn can_upcast(dtype: &ValueType) -> bool {
impl<T: IntoTensorElementType + Debug> DowncastableTarget for TensorValueType<T> {
fn can_downcast(dtype: &ValueType) -> bool {
match dtype {
ValueType::Tensor { ty, .. } => *ty == T::into_tensor_element_type(),
_ => false
Expand Down
34 changes: 18 additions & 16 deletions src/value/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,14 +249,14 @@ impl<'v, Type: ValueTypeMarker + ?Sized> DerefMut for ValueRefMut<'v, Type> {
///
/// ## Usage
/// You can access the data contained in a `Value` by using the relevant `extract` methods.
/// You can also use [`DynValue::upcast`] to attempt to convert from a [`DynValue`] to a more strongly typed value.
/// You can also use [`DynValue::downcast`] to attempt to convert from a [`DynValue`] to a more strongly typed value.
///
/// For dynamic values, where the type is not known at compile time, see the `try_extract_*` methods:
/// - [`Tensor::try_extract_tensor`], [`Tensor::try_extract_raw_tensor`]
/// - [`Sequence::try_extract_sequence`]
/// - [`Map::try_extract_map`]
///
/// If the type was created from Rust (via a method like [`Tensor::from_array`] or via upcasting), you can directly
/// If the type was created from Rust (via a method like [`Tensor::from_array`] or via downcasting), you can directly
/// extract the data using the infallible extract methods:
/// - [`Tensor::extract_tensor`], [`Tensor::extract_raw_tensor`]
#[derive(Debug)]
Expand All @@ -267,7 +267,7 @@ pub struct Value<Type: ValueTypeMarker + ?Sized = DynValueTypeMarker> {

/// A dynamic value, which could be a [`Tensor`], [`Sequence`], or [`Map`].
///
/// To attempt to convert a dynamic value to a strongly typed value, use [`DynValue::upcast`]. You can also attempt to
/// To attempt to convert a dynamic value to a strongly typed value, use [`DynValue::downcast`]. You can also attempt to
/// extract data from dynamic values directly using `try_extract_*` methods; see [`Value`] for more information.
pub type DynValue = Value<DynValueTypeMarker>;

Expand All @@ -277,14 +277,14 @@ pub type DynValue = Value<DynValueTypeMarker>;
/// inherits this trait), i.e. [`Tensor`]s, [`DynTensor`]s, and [`DynValue`]s.
pub trait ValueTypeMarker: Debug {}

/// Represents a type that a [`DynValue`] can be upcast to.
pub trait UpcastableTarget: ValueTypeMarker {
fn can_upcast(dtype: &ValueType) -> bool;
/// Represents a type that a [`DynValue`] can be downcast to.
pub trait DowncastableTarget: ValueTypeMarker {
fn can_downcast(dtype: &ValueType) -> bool;
}

// this implementation is used in case we want to extract `DynValue`s from a [`Sequence`]; see `try_extract_sequence`
impl UpcastableTarget for DynValueTypeMarker {
fn can_upcast(_: &ValueType) -> bool {
impl DowncastableTarget for DynValueTypeMarker {
fn can_downcast(_: &ValueType) -> bool {
true
}
}
Expand Down Expand Up @@ -406,21 +406,23 @@ impl<Type: ValueTypeMarker + ?Sized> Value<Type> {
pub fn into_dyn(self) -> DynValue {
unsafe { std::mem::transmute(self) }
}
}

/// Attempts to upcast a dynamic value (like [`DynValue`] or [`DynTensor`]) to a more strongly typed variant,
impl Value<DynValueTypeMarker> {
/// Attempts to downcast a dynamic value (like [`DynValue`] or [`DynTensor`]) to a more strongly typed variant,
/// like [`Tensor<T>`].
#[inline]
pub fn upcast<OtherType: ValueTypeMarker + UpcastableTarget + Debug + ?Sized>(self) -> Result<Value<OtherType>> {
pub fn downcast<OtherType: ValueTypeMarker + DowncastableTarget + Debug + ?Sized>(self) -> Result<Value<OtherType>> {
let dt = self.dtype()?;
if OtherType::can_upcast(&dt) { Ok(unsafe { std::mem::transmute(self) }) } else { panic!() }
if OtherType::can_downcast(&dt) { Ok(unsafe { std::mem::transmute(self) }) } else { panic!() }
}

/// Attempts to upcast a dynamic value (like [`DynValue`] or [`DynTensor`]) to a more strongly typed reference
/// Attempts to downcast a dynamic value (like [`DynValue`] or [`DynTensor`]) to a more strongly typed reference
/// variant, like [`TensorRef<T>`].
#[inline]
pub fn upcast_ref<OtherType: ValueTypeMarker + UpcastableTarget + Debug + ?Sized>(&self) -> Result<ValueRef<'_, OtherType>> {
pub fn downcast_ref<OtherType: ValueTypeMarker + DowncastableTarget + Debug + ?Sized>(&self) -> Result<ValueRef<'_, OtherType>> {
let dt = self.dtype()?;
if OtherType::can_upcast(&dt) {
if OtherType::can_downcast(&dt) {
Ok(ValueRef::new(unsafe {
Value::from_ptr_nodrop(
NonNull::new_unchecked(self.ptr()),
Expand All @@ -435,9 +437,9 @@ impl<Type: ValueTypeMarker + ?Sized> Value<Type> {
/// Attempts to upcast a dynamic value (like [`DynValue`] or [`DynTensor`]) to a more strongly typed
/// mutable-reference variant, like [`TensorRefMut<T>`].
#[inline]
pub fn upcast_mut<OtherType: ValueTypeMarker + UpcastableTarget + Debug + ?Sized>(&mut self) -> Result<ValueRefMut<'_, OtherType>> {
pub fn downcast_mut<OtherType: ValueTypeMarker + DowncastableTarget + Debug + ?Sized>(&mut self) -> Result<ValueRefMut<'_, OtherType>> {
let dt = self.dtype()?;
if OtherType::can_upcast(&dt) {
if OtherType::can_downcast(&dt) {
Ok(ValueRefMut::new(unsafe {
Value::from_ptr_nodrop(
NonNull::new_unchecked(self.ptr()),
Expand Down

0 comments on commit cedeb55

Please sign in to comment.