diff --git a/docs/migrating/v2.mdx b/docs/migrating/v2.mdx index f578517d..1359f03b 100644 --- a/docs/migrating/v2.mdx +++ b/docs/migrating/v2.mdx @@ -35,15 +35,15 @@ You can still extract tensors, maps, or sequence values normally from a `DynValu let generated_tokens: ArrayViewD = outputs["output1"].try_extract_tensor()?; ``` -`DynValue` can be `upcast()`ed to the more specialized types, like `DynMap` or `Tensor`: +`DynValue` can be `downcast()`ed to the more specialized types, like `DynMap` or `Tensor`: ```rust -let tensor: Tensor = value.upcast()?; -let map: DynMap = value.upcast()?; +let tensor: Tensor = value.downcast()?; +let map: DynMap = value.downcast()?; ``` -Similarly, a strongly-typed value like `Tensor` can be downcast back into a `DynValue` or `DynTensor`. +Similarly, a strongly-typed value like `Tensor` can be upcast back into a `DynValue` or `DynTensor`. ```rust -let dyn_tensor: DynTensor = tensor.downcast(); +let dyn_tensor: DynTensor = tensor.upcast(); let dyn_value: DynValue = tensor.into_dyn(); ``` diff --git a/src/lib.rs b/src/lib.rs index ca59170b..69d9b95c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -64,9 +64,9 @@ pub use self::session::{ pub use self::tensor::ArrayExtensions; pub use self::tensor::{IntoTensorElementType, TensorElementType}; pub use self::value::{ - DynMap, DynMapRef, DynMapRefMut, DynMapValueType, DynSequence, DynSequenceRef, DynSequenceRefMut, DynSequenceValueType, DynTensor, DynTensorRef, - DynTensorRefMut, DynTensorValueType, DynValue, DynValueTypeMarker, Map, MapRef, MapRefMut, MapValueType, MapValueTypeMarker, Sequence, SequenceRef, - SequenceRefMut, SequenceValueType, SequenceValueTypeMarker, Tensor, TensorRef, TensorRefMut, TensorValueTypeMarker, UpcastableTarget, Value, ValueRef, + DowncastableTarget, DynMap, DynMapRef, DynMapRefMut, DynMapValueType, DynSequence, DynSequenceRef, DynSequenceRefMut, DynSequenceValueType, DynTensor, + DynTensorRef, DynTensorRefMut, DynTensorValueType, DynValue, DynValueTypeMarker, Map, MapRef, MapRefMut, MapValueType, MapValueTypeMarker, Sequence, + SequenceRef, SequenceRefMut, SequenceValueType, SequenceValueTypeMarker, Tensor, TensorRef, TensorRefMut, TensorValueTypeMarker, Value, ValueRef, ValueRefMut, ValueType, ValueTypeMarker }; diff --git a/src/value/impl_map.rs b/src/value/impl_map.rs index 8cff3c03..1421e5a5 100644 --- a/src/value/impl_map.rs +++ b/src/value/impl_map.rs @@ -139,13 +139,13 @@ impl`] to a type-erased [`DynMap`]. #[inline] - pub fn downcast(self) -> DynMap { + pub fn upcast(self) -> DynMap { unsafe { std::mem::transmute(self) } } /// Converts from a strongly-typed [`Map`] to a reference to a type-erased [`DynMap`]. #[inline] - pub fn downcast_ref(&self) -> DynMapRef { + pub fn upcast_ref(&self) -> DynMapRef { DynMapRef::new(unsafe { Value::from_ptr_nodrop( NonNull::new_unchecked(self.ptr()), @@ -156,7 +156,7 @@ impl`] to a mutable reference to a type-erased [`DynMap`]. #[inline] - pub fn downcast_mut(&mut self) -> DynMapRefMut { + pub fn upcast_mut(&mut self) -> DynMapRefMut { DynMapRefMut::new(unsafe { Value::from_ptr_nodrop( NonNull::new_unchecked(self.ptr()), diff --git a/src/value/impl_sequence.rs b/src/value/impl_sequence.rs index 8169157c..a94dcfab 100644 --- a/src/value/impl_sequence.rs +++ b/src/value/impl_sequence.rs @@ -4,7 +4,7 @@ use std::{ ptr::{self, NonNull} }; -use super::{UpcastableTarget, ValueInner, ValueTypeMarker}; +use super::{DowncastableTarget, ValueInner, ValueTypeMarker}; use crate::{memory::Allocator, ortsys, Error, Result, Value, ValueRef, ValueRefMut, ValueType}; pub trait SequenceValueTypeMarker: ValueTypeMarker {} @@ -15,9 +15,9 @@ impl ValueTypeMarker for DynSequenceValueType {} impl SequenceValueTypeMarker for DynSequenceValueType {} #[derive(Debug)] -pub struct SequenceValueType(PhantomData); -impl ValueTypeMarker for SequenceValueType {} -impl SequenceValueTypeMarker for SequenceValueType {} +pub struct SequenceValueType(PhantomData); +impl ValueTypeMarker for SequenceValueType {} +impl SequenceValueTypeMarker for SequenceValueType {} pub type DynSequence = Value; pub type Sequence = Value>; @@ -28,7 +28,7 @@ pub type SequenceRef<'v, T> = ValueRef<'v, SequenceValueType>; pub type SequenceRefMut<'v, T> = ValueRefMut<'v, SequenceValueType>; impl Value { - pub fn try_extract_sequence<'s, OtherType: ValueTypeMarker + UpcastableTarget + Debug + Sized>( + pub fn try_extract_sequence<'s, OtherType: ValueTypeMarker + DowncastableTarget + Debug + Sized>( &'s self, allocator: &Allocator ) -> Result>> { @@ -47,7 +47,7 @@ impl Value { lifetime: PhantomData }; let value_type = value.dtype()?; - if !OtherType::can_upcast(&value.dtype()?) { + if !OtherType::can_downcast(&value.dtype()?) { return Err(Error::InvalidSequenceElementType { actual: value_type }); } @@ -60,7 +60,7 @@ impl Value { } } -impl Value> { +impl Value> { /// Creates a [`Sequence`] from an array of [`Value`]. /// /// This `Value` must be either a [`crate::Tensor`] or [`crate::Map`]. @@ -99,20 +99,20 @@ impl Value Value> { +impl Value> { pub fn extract_sequence<'s>(&'s self, allocator: &Allocator) -> Vec> { self.try_extract_sequence(allocator).expect("Failed to extract sequence") } /// Converts from a strongly-typed [`Sequence`] to a type-erased [`DynSequence`]. #[inline] - pub fn downcast(self) -> DynSequence { + pub fn upcast(self) -> DynSequence { unsafe { std::mem::transmute(self) } } /// Converts from a strongly-typed [`Sequence`] to a reference to a type-erased [`DynTensor`]. #[inline] - pub fn downcast_ref(&self) -> DynSequenceRef { + pub fn upcast_ref(&self) -> DynSequenceRef { DynSequenceRef::new(unsafe { Value::from_ptr_nodrop( NonNull::new_unchecked(self.ptr()), @@ -123,7 +123,7 @@ impl Value`] to a mutable reference to a type-erased [`DynTensor`]. #[inline] - pub fn downcast_mut(&mut self) -> DynSequenceRefMut { + pub fn upcast_mut(&mut self) -> DynSequenceRefMut { DynSequenceRefMut::new(unsafe { Value::from_ptr_nodrop( NonNull::new_unchecked(self.ptr()), diff --git a/src/value/impl_tensor/create.rs b/src/value/impl_tensor/create.rs index 37705a01..41eddbd5 100644 --- a/src/value/impl_tensor/create.rs +++ b/src/value/impl_tensor/create.rs @@ -527,7 +527,7 @@ where { type Error = Error; fn try_from(arr: &'i CowArray<'v, T, D>) -> Result { - Tensor::from_array(arr).map(|c| c.downcast()) + Tensor::from_array(arr).map(|c| c.upcast()) } } @@ -536,7 +536,7 @@ where impl<'v, T: IntoTensorElementType + Debug + Clone + 'static, D: Dimension + 'static> TryFrom> for DynTensor { type Error = Error; fn try_from(arr: ArrayView<'v, T, D>) -> Result { - Tensor::from_array(arr).map(|c| c.downcast()) + Tensor::from_array(arr).map(|c| c.upcast()) } } @@ -573,7 +573,7 @@ macro_rules! impl_try_from { impl TryFrom<$t> for DynTensor { type Error = Error; fn try_from(value: $t) -> Result { - Tensor::from_array(value).map(|c| c.downcast()) + Tensor::from_array(value).map(|c| c.upcast()) } } impl TryFrom<$t> for crate::DynValue { @@ -597,7 +597,7 @@ macro_rules! impl_try_from { impl TryFrom<$t> for DynTensor { type Error = Error; fn try_from(value: $t) -> Result { - Tensor::from_array(value).map(|c| c.downcast()) + Tensor::from_array(value).map(|c| c.upcast()) } } #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] diff --git a/src/value/impl_tensor/mod.rs b/src/value/impl_tensor/mod.rs index da100929..d7f1db3b 100644 --- a/src/value/impl_tensor/mod.rs +++ b/src/value/impl_tensor/mod.rs @@ -8,7 +8,7 @@ use std::{ ptr::NonNull }; -use super::{UpcastableTarget, Value, ValueInner, ValueTypeMarker}; +use super::{DowncastableTarget, Value, ValueInner, ValueTypeMarker}; use crate::{ortsys, DynValue, Error, IntoTensorElementType, MemoryInfo, Result, ValueRef, ValueRefMut, ValueType}; pub trait TensorValueTypeMarker: ValueTypeMarker {} @@ -31,8 +31,8 @@ pub type DynTensorRefMut<'v> = ValueRefMut<'v, DynTensorValueType>; pub type TensorRef<'v, T> = ValueRef<'v, TensorValueType>; pub type TensorRefMut<'v, T> = ValueRefMut<'v, TensorValueType>; -impl UpcastableTarget for DynTensorValueType { - fn can_upcast(dtype: &ValueType) -> bool { +impl DowncastableTarget for DynTensorValueType { + fn can_downcast(dtype: &ValueType) -> bool { matches!(dtype, ValueType::Tensor { .. }) } } @@ -63,13 +63,13 @@ impl Value { impl Tensor { /// Converts from a strongly-typed [`Tensor`] to a type-erased [`DynTensor`]. #[inline] - pub fn downcast(self) -> DynTensor { + pub fn upcast(self) -> DynTensor { unsafe { std::mem::transmute(self) } } /// Converts from a strongly-typed [`Tensor`] to a reference to a type-erased [`DynTensor`]. #[inline] - pub fn downcast_ref(&self) -> DynTensorRef { + pub fn upcast_ref(&self) -> DynTensorRef { DynTensorRef::new(unsafe { Value::from_ptr_nodrop( NonNull::new_unchecked(self.ptr()), @@ -80,7 +80,7 @@ impl Tensor { /// Converts from a strongly-typed [`Tensor`] to a mutable reference to a type-erased [`DynTensor`]. #[inline] - pub fn downcast_mut(&mut self) -> DynTensorRefMut { + pub fn upcast_mut(&mut self) -> DynTensorRefMut { DynTensorRefMut::new(unsafe { Value::from_ptr_nodrop( NonNull::new_unchecked(self.ptr()), @@ -90,8 +90,8 @@ impl Tensor { } } -impl UpcastableTarget for TensorValueType { - fn can_upcast(dtype: &ValueType) -> bool { +impl DowncastableTarget for TensorValueType { + fn can_downcast(dtype: &ValueType) -> bool { match dtype { ValueType::Tensor { ty, .. } => *ty == T::into_tensor_element_type(), _ => false diff --git a/src/value/mod.rs b/src/value/mod.rs index 60e5e046..61b49d18 100644 --- a/src/value/mod.rs +++ b/src/value/mod.rs @@ -249,14 +249,14 @@ impl<'v, Type: ValueTypeMarker + ?Sized> DerefMut for ValueRefMut<'v, Type> { /// /// ## Usage /// You can access the data contained in a `Value` by using the relevant `extract` methods. -/// You can also use [`DynValue::upcast`] to attempt to convert from a [`DynValue`] to a more strongly typed value. +/// You can also use [`DynValue::downcast`] to attempt to convert from a [`DynValue`] to a more strongly typed value. /// /// For dynamic values, where the type is not known at compile time, see the `try_extract_*` methods: /// - [`Tensor::try_extract_tensor`], [`Tensor::try_extract_raw_tensor`] /// - [`Sequence::try_extract_sequence`] /// - [`Map::try_extract_map`] /// -/// If the type was created from Rust (via a method like [`Tensor::from_array`] or via upcasting), you can directly +/// If the type was created from Rust (via a method like [`Tensor::from_array`] or via downcasting), you can directly /// extract the data using the infallible extract methods: /// - [`Tensor::extract_tensor`], [`Tensor::extract_raw_tensor`] #[derive(Debug)] @@ -267,7 +267,7 @@ pub struct Value { /// A dynamic value, which could be a [`Tensor`], [`Sequence`], or [`Map`]. /// -/// To attempt to convert a dynamic value to a strongly typed value, use [`DynValue::upcast`]. You can also attempt to +/// To attempt to convert a dynamic value to a strongly typed value, use [`DynValue::downcast`]. You can also attempt to /// extract data from dynamic values directly using `try_extract_*` methods; see [`Value`] for more information. pub type DynValue = Value; @@ -277,14 +277,14 @@ pub type DynValue = Value; /// inherits this trait), i.e. [`Tensor`]s, [`DynTensor`]s, and [`DynValue`]s. pub trait ValueTypeMarker: Debug {} -/// Represents a type that a [`DynValue`] can be upcast to. -pub trait UpcastableTarget: ValueTypeMarker { - fn can_upcast(dtype: &ValueType) -> bool; +/// Represents a type that a [`DynValue`] can be downcast to. +pub trait DowncastableTarget: ValueTypeMarker { + fn can_downcast(dtype: &ValueType) -> bool; } // this implementation is used in case we want to extract `DynValue`s from a [`Sequence`]; see `try_extract_sequence` -impl UpcastableTarget for DynValueTypeMarker { - fn can_upcast(_: &ValueType) -> bool { +impl DowncastableTarget for DynValueTypeMarker { + fn can_downcast(_: &ValueType) -> bool { true } } @@ -406,21 +406,23 @@ impl Value { pub fn into_dyn(self) -> DynValue { unsafe { std::mem::transmute(self) } } +} - /// Attempts to upcast a dynamic value (like [`DynValue`] or [`DynTensor`]) to a more strongly typed variant, +impl Value { + /// Attempts to downcast a dynamic value (like [`DynValue`] or [`DynTensor`]) to a more strongly typed variant, /// like [`Tensor`]. #[inline] - pub fn upcast(self) -> Result> { + pub fn downcast(self) -> Result> { let dt = self.dtype()?; - if OtherType::can_upcast(&dt) { Ok(unsafe { std::mem::transmute(self) }) } else { panic!() } + if OtherType::can_downcast(&dt) { Ok(unsafe { std::mem::transmute(self) }) } else { panic!() } } - /// Attempts to upcast a dynamic value (like [`DynValue`] or [`DynTensor`]) to a more strongly typed reference + /// Attempts to downcast a dynamic value (like [`DynValue`] or [`DynTensor`]) to a more strongly typed reference /// variant, like [`TensorRef`]. #[inline] - pub fn upcast_ref(&self) -> Result> { + pub fn downcast_ref(&self) -> Result> { let dt = self.dtype()?; - if OtherType::can_upcast(&dt) { + if OtherType::can_downcast(&dt) { Ok(ValueRef::new(unsafe { Value::from_ptr_nodrop( NonNull::new_unchecked(self.ptr()), @@ -435,9 +437,9 @@ impl Value { /// Attempts to upcast a dynamic value (like [`DynValue`] or [`DynTensor`]) to a more strongly typed /// mutable-reference variant, like [`TensorRefMut`]. #[inline] - pub fn upcast_mut(&mut self) -> Result> { + pub fn downcast_mut(&mut self) -> Result> { let dt = self.dtype()?; - if OtherType::can_upcast(&dt) { + if OtherType::can_downcast(&dt) { Ok(ValueRefMut::new(unsafe { Value::from_ptr_nodrop( NonNull::new_unchecked(self.ptr()),