-
-
Notifications
You must be signed in to change notification settings - Fork 106
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c8b36f3
commit 89416d1
Showing
3 changed files
with
107 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
--- | ||
title: 'Value' | ||
--- | ||
|
||
import { Callout } from 'nextra/components'; | ||
|
||
For ONNX Runtime, a **value** represents any type that can be given to/returned from a session or operator. Values come in three main types: | ||
- **Tensors** (multi-dimensional arrays). This is the most common type of `Value`. | ||
- **Maps** map a key type to a value type, similar to Rust's `HashMap<K, V>`. | ||
- **Sequences** are homogenously-typed dynamically-sized lists, similar to Rust's `Vec<T>`. The only values allowed in sequences are tensors, or maps of tensors. | ||
|
||
In order to actually use the data in these containers, you can use the `.try_extract_*` methods. `try_extract_tensor(_mut)` extracts an `ndarray::ArrayView(Mut)` from the value if it is a tensor. `try_extract_sequence` returns a `Vec` of values, and `try_extract_map` returns a `HashMap`. | ||
|
||
Sessions in `ort` return a map of `DynValue`s. You can determine a value's type via its `.dtype()` method. You can also use fallible methods to extract data from this value - for example, [`DynValue::try_extract_tensor`](https://ort.pyke.io/rustdoc/ort/type.DynValue.html#method.try_extract_tensor), which fails if the value is not a tensor. Often times though, you'll want to reuse the same value which you are certain is a tensor - in which case, you can **downcast** the value. | ||
|
||
## Downcasting | ||
**Downcasting** means to convert a `Dyn` type like `DynValue` to stronger type like `DynTensor`. Downcasting can be performed using the `.downcast()` function on `DynValue`: | ||
```rs | ||
let value: ort::DynValue = outputs.remove("output0").unwrap(); | ||
|
||
let dyn_tensor: ort::DynTensor = value.downcast()?; | ||
``` | ||
|
||
If `value` is not actually a tensor, the `downcast()` call will fail. | ||
|
||
`DynTensor` allows you to use | ||
|
||
### Stronger types | ||
`DynTensor` means that the type **is** a tensor, but the *element type is unknown*. There are also `DynSequence`s and `DynMap`s, which have the same meaning - the element/key/value types are unknown. | ||
|
||
The strongly typed variants of these types - `Tensor<T>`, `Sequence<T>`, and `Map<K, V>`, can be directly downcasted to, too: | ||
```rs | ||
let dyn_value: ort::DynValue = outputs.remove("output0").unwrap(); | ||
|
||
let f32_tensor: ort::Tensor<f32> = dyn_value.downcast()?; | ||
``` | ||
|
||
If `value` is not a tensor, **or** if the element type of the value does not match what was requested (`f32`), the `downcast()` call will fail. | ||
|
||
Stronger typed values have infallible variants of the `.try_extract_*` methods: | ||
```rs | ||
// We could try to extract a tensor directly from a `DynValue`... | ||
let f32_array: ArrayViewD<f32> = dyn_value.try_extract_tensor()?; | ||
|
||
// Or, we can first onvert it to a tensor, and then extract afterwards: | ||
let tensor: ort::Tensor<f32> = dyn_value.downcast()?; | ||
let f32_array = tensor.extract_tensor(); // no `?` required, this will never fail! | ||
``` | ||
|
||
## Upcasting | ||
**Upcasting** means to convert a strongly-typed value type like `Tensor<f32>` to a weaker type like `DynTensor` or `DynValue`. This can be useful if you have code that stores values of different types, e.g. in a `HashMap<String, DynValue>`. | ||
|
||
Strongly-typed value types like `Tensor<f32>` can be converted into a `DynTensor` using `.upcast()`: | ||
```rs | ||
let dyn_tensor = f32_tensor.upcast(); | ||
// type is DynTensor | ||
``` | ||
|
||
`Tensor<f32>` or `DynTensor` can be cast to a `DynValue` by using `.into_dyn()`: | ||
```rs | ||
let dyn_value = f32_tensor.into_dyn(); | ||
// type is DynValue | ||
``` | ||
|
||
Upcasting a value doesn't change its underlying type; it just removes the specialization. You cannot, for example, upcast a `Tensor<f32>` to a `DynValue` and then downcast it to a `Sequence`; it's still a `Tensor<f32>`, just contained in a different type. | ||
|
||
## Conversion recap | ||
- `DynValue` represents a value that can be any type - tensor, sequence, or map. The type can be retrieved with `.dtype()`. | ||
- `DynTensor`, `DynMap`, and `DynSequence` are values with known container types, but unknown element types. | ||
- `Tensor<T>`, `Map<K, V>`, and `Sequence<T>` are values with known container and element types. | ||
- `Tensor<T>` and co. can be converted from/to their dyn types using `.downcast()`/`.upcast()`, respectively. | ||
- `Tensor<T>`/`DynTensor` and co. can be converted to `DynValue`s using `.into_dyn()`. | ||
|
||
<img width="100%" src="/assets/casting-map.png" alt="An illustration of the relationship between value types as described above, used for visualization purposes." /> | ||
|
||
<Callout type='info'> | ||
Note that `DynTensor` cannot be downcast to `Tensor<T>`, but `DynTensor` can be upcast to `DynValue` with `.into_dyn()`, and then downcast to `Tensor<T>` with `.downcast()`. | ||
|
||
Downcasts are cheap, as they only check the value's type. Upcasts compile to a no-op. | ||
</Callout> | ||
|
||
## Views | ||
A view (also called a ref) is functionally a borrowed variant of a value. There are also mutable views, which are equivalent to mutably borrowed values. Views are represented as separate structs so that they can be down/upcasted. | ||
|
||
View types are suffixed with `Ref` or `RefMut` for shared/mutable variants respectively: | ||
- Tensors have `DynTensorRef(Mut)` and `TensorRef(Mut)`. | ||
- Maps have `DynMapRef(Mut)` and `MapRef(Mut)`. | ||
- Sequences have `DynSequenceRef(Mut)` and `SequenceRef(Mut)`. | ||
|
||
These views can be acquired with `.view()` or `.view_mut()` on a value type: | ||
```rs | ||
let my_tensor: ort::Tensor<f32> = Tensor::new(...)?; | ||
|
||
let tensor_view: ort::TensorRef<'_, f32> = my_tensor.view(); | ||
``` | ||
|
||
Views act identically to a borrow of their type - `TensorRef` supports `extract_tensor`, `TensorRefMut` supports `extract_tensor_mut`. The same is true for sequences & maps. Views also support down/upcasting via `.downcast()` & `.into_dyn()` (but not `.upcast()` at the moment). | ||
|
||
You can also directly downcast a value to a stronger-typed view using `.downcast_ref()` and `.downcast_mut()`: | ||
```rs | ||
let tensor_view: ort::TensorRef<'_, f32> = dyn_value.downcast_ref()?; | ||
// is equivalent to | ||
let tensor_view: ort::TensorRef<'_, f32> = dyn_value.view().downcast()?; | ||
``` |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.