docs: explain up/downcasting
decahedron1 committed Sep 24, 2024
1 parent c8b36f3 commit 89416d1
Showing 3 changed files with 107 additions and 0 deletions.
3 changes: 3 additions & 0 deletions docs/pages/_meta.json
@@ -25,6 +25,9 @@
"setup": {
"title": "Setup"
},
"fundamentals": {
"title": "Fundamentals"
},
"perf": {
"title": "Performance"
},
104 changes: 104 additions & 0 deletions docs/pages/fundamentals/value.mdx
@@ -0,0 +1,104 @@
---
title: 'Value'
---

import { Callout } from 'nextra/components';

For ONNX Runtime, a **value** represents any type that can be given to/returned from a session or operator. Values come in three main types:
- **Tensors** (multi-dimensional arrays). This is the most common type of `Value`.
- **Maps** map a key type to a value type, similar to Rust's `HashMap<K, V>`.
- **Sequences** are homogeneously typed, dynamically sized lists, similar to Rust's `Vec<T>`. The only values allowed in sequences are tensors, or maps of tensors.

In order to actually use the data in these containers, you can use the `.try_extract_*` methods. `try_extract_tensor(_mut)` extracts an `ndarray::ArrayView(Mut)` from the value if it is a tensor. `try_extract_sequence` returns a `Vec` of values, and `try_extract_map` returns a `HashMap`.

Sessions in `ort` return a map of `DynValue`s. You can determine a value's type via its `.dtype()` method. You can also use fallible methods to extract data from this value - for example, [`DynValue::try_extract_tensor`](https://ort.pyke.io/rustdoc/ort/type.DynValue.html#method.try_extract_tensor), which fails if the value is not a tensor. Often, though, you'll want to reuse a value that you are certain is a tensor - in which case, you can **downcast** the value.

## Downcasting
**Downcasting** means converting a `Dyn` type like `DynValue` to a stronger type like `DynTensor`. Downcasting can be performed using the `.downcast()` function on `DynValue`:
```rs
let value: ort::DynValue = outputs.remove("output0").unwrap();

let dyn_tensor: ort::DynTensor = value.downcast()?;
```

If `value` is not actually a tensor, the `downcast()` call will fail.
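
To make the failure mode concrete, here is a minimal sketch of the same idea in plain Rust. It does not use `ort` at all: `ToyDynValue`, `ToyTensor`, and `ToyDowncastError` are hypothetical stand-ins invented for illustration, and the real crate's types and error values will differ.

```rust
// Toy model only: none of these types are part of the `ort` API.

/// Stand-in for a value whose container kind is only known at runtime.
#[derive(Debug, Clone, PartialEq)]
enum ToyDynValue {
    Tensor(Vec<f32>),
    Sequence(Vec<Vec<f32>>),
}

/// Stand-in for a value that is statically known to be a tensor.
#[derive(Debug, Clone, PartialEq)]
struct ToyTensor(Vec<f32>);

/// Describes a failed downcast: what was requested and what was found.
#[derive(Debug, Clone, PartialEq)]
struct ToyDowncastError {
    expected: &'static str,
    found: &'static str,
}

impl ToyDynValue {
    /// Name of the container kind held by this value.
    fn kind(&self) -> &'static str {
        match self {
            ToyDynValue::Tensor(_) => "tensor",
            ToyDynValue::Sequence(_) => "sequence",
        }
    }

    /// The runtime check happens once, here. On success the caller holds a
    /// type that needs no further checks.
    fn downcast(self) -> Result<ToyTensor, ToyDowncastError> {
        match self {
            ToyDynValue::Tensor(data) => Ok(ToyTensor(data)),
            other => Err(ToyDowncastError {
                expected: "tensor",
                found: other.kind(),
            }),
        }
    }
}
```

In this sketch, calling `downcast()` on a `ToyDynValue::Sequence` returns an `Err` rather than panicking, which mirrors why the `ort` examples above end the call with `?`.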

`DynTensor` allows you to use

### Stronger types
`DynTensor` means that the type **is** a tensor, but the *element type is unknown*. There are also `DynSequence`s and `DynMap`s, which have the same meaning - the element/key/value types are unknown.

You can also downcast directly to the strongly typed variants of these types - `Tensor<T>`, `Sequence<T>`, and `Map<K, V>`:
```rs
let dyn_value: ort::DynValue = outputs.remove("output0").unwrap();

let f32_tensor: ort::Tensor<f32> = dyn_value.downcast()?;
```

If `dyn_value` is not a tensor, **or** if the element type of the value does not match what was requested (`f32`), the `downcast()` call will fail.
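
The two failure conditions (wrong container, wrong element type) can be sketched with a small generic model in plain Rust. Everything here is hypothetical and exists only for illustration: `ToyDynValue`, `ToyTensor<T>`, and the `ToyElement` trait are not `ort` types, and this sketch returns the original value on failure, which is an assumption rather than a description of `ort`'s error type.

```rust
// Toy model only: none of these types are part of the `ort` API.

/// Stand-in for a dynamically typed value. The element type is part of the
/// runtime tag, not of the Rust type.
#[derive(Debug, Clone, PartialEq)]
enum ToyDynValue {
    TensorF32(Vec<f32>),
    TensorI64(Vec<i64>),
    Sequence(Vec<ToyDynValue>),
}

/// Stand-in for a tensor whose element type is known at compile time.
#[derive(Debug, Clone, PartialEq)]
struct ToyTensor<T>(Vec<T>);

/// Element types that a `ToyDynValue` can be downcast to.
trait ToyElement: Sized {
    /// Returns the data if the runtime tag matches `Self`; otherwise hands
    /// the untouched value back to the caller.
    fn take(value: ToyDynValue) -> Result<Vec<Self>, ToyDynValue>;
}

impl ToyElement for f32 {
    fn take(value: ToyDynValue) -> Result<Vec<f32>, ToyDynValue> {
        match value {
            ToyDynValue::TensorF32(data) => Ok(data),
            other => Err(other),
        }
    }
}

impl ToyElement for i64 {
    fn take(value: ToyDynValue) -> Result<Vec<i64>, ToyDynValue> {
        match value {
            ToyDynValue::TensorI64(data) => Ok(data),
            other => Err(other),
        }
    }
}

impl ToyDynValue {
    /// Fails if the value is not a tensor OR if its element type is not `T`.
    fn downcast<T: ToyElement>(self) -> Result<ToyTensor<T>, ToyDynValue> {
        T::take(self).map(ToyTensor)
    }
}
```

As in the `ort` snippet above, the requested element type is chosen by the type annotation on the binding, for example `let t: Result<ToyTensor<f32>, _> = value.downcast();`.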

Strongly typed values have infallible variants of the `.try_extract_*` methods:
```rs
// We could try to extract a tensor directly from a `DynValue`...
let f32_array: ArrayViewD<f32> = dyn_value.try_extract_tensor()?;

// Or, we can first convert it to a tensor, and then extract afterwards:
let tensor: ort::Tensor<f32> = dyn_value.downcast()?;
let f32_array = tensor.extract_tensor(); // no `?` required, this will never fail!
```

## Upcasting
**Upcasting** means to convert a strongly-typed value type like `Tensor<f32>` to a weaker type like `DynTensor` or `DynValue`. This can be useful if you have code that stores values of different types, e.g. in a `HashMap<String, DynValue>`.

Strongly-typed value types like `Tensor<f32>` can be converted into a `DynTensor` using `.upcast()`:
```rs
let dyn_tensor = f32_tensor.upcast();
// type is DynTensor
```

`Tensor<f32>` or `DynTensor` can be cast to a `DynValue` by using `.into_dyn()`:
```rs
let dyn_value = f32_tensor.into_dyn();
// type is DynValue
```

Upcasting a value doesn't change its underlying type; it just removes the specialization. You cannot, for example, upcast a `Tensor<f32>` to a `DynValue` and then downcast it to a `Sequence`; it's still a `Tensor<f32>`, just contained in a different type.
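
The following plain-Rust sketch illustrates both points: storing differently typed values in one `HashMap`, and the fact that erasing the static type does not change what the value is. All names (`ToyDynValue`, `ToyTensor`, `ToySequence`, `build_store`, and the `downcast_*` helpers) are hypothetical and are not part of `ort`; the real crate uses a single generic `.downcast()` instead of per-type helpers.

```rust
// Toy model only: none of these types are part of the `ort` API.
use std::collections::HashMap;

/// Stand-in for a dynamically typed value.
#[derive(Debug, Clone, PartialEq)]
enum ToyDynValue {
    Tensor(Vec<f32>),
    Sequence(Vec<Vec<f32>>),
}

#[derive(Debug, Clone, PartialEq)]
struct ToyTensor(Vec<f32>);

#[derive(Debug, Clone, PartialEq)]
struct ToySequence(Vec<Vec<f32>>);

impl ToyTensor {
    /// Upcast: forget the static type, keep the data and its runtime tag.
    fn into_dyn(self) -> ToyDynValue {
        ToyDynValue::Tensor(self.0)
    }
}

impl ToySequence {
    fn into_dyn(self) -> ToyDynValue {
        ToyDynValue::Sequence(self.0)
    }
}

impl ToyDynValue {
    fn downcast_tensor(self) -> Option<ToyTensor> {
        match self {
            ToyDynValue::Tensor(data) => Some(ToyTensor(data)),
            _ => None,
        }
    }

    fn downcast_sequence(self) -> Option<ToySequence> {
        match self {
            ToyDynValue::Sequence(items) => Some(ToySequence(items)),
            _ => None,
        }
    }
}

/// Stores values of different static types side by side.
fn build_store() -> HashMap<String, ToyDynValue> {
    let mut store = HashMap::new();
    store.insert("logits".to_string(), ToyTensor(vec![0.1, 0.9]).into_dyn());
    store.insert(
        "tokens".to_string(),
        ToySequence(vec![vec![1.0], vec![2.0, 3.0]]).into_dyn(),
    );
    store
}
```

Taking `"logits"` back out of the store and calling `downcast_sequence()` yields `None`, while `downcast_tensor()` recovers the original data: the upcast only removed the static type, not the value's identity.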

## Conversion recap
- `DynValue` represents a value that can be any type - tensor, sequence, or map. The type can be retrieved with `.dtype()`.
- `DynTensor`, `DynMap`, and `DynSequence` are values with known container types, but unknown element types.
- `Tensor<T>`, `Map<K, V>`, and `Sequence<T>` are values with known container and element types.
- `Tensor<T>` and co. can be converted from/to their dyn types using `.downcast()`/`.upcast()`, respectively.
- `Tensor<T>`/`DynTensor` and co. can be converted to `DynValue`s using `.into_dyn()`.

<img width="100%" src="/assets/casting-map.png" alt="An illustration of the relationship between value types as described above, used for visualization purposes." />

<Callout type='info'>
Note that `DynTensor` cannot be downcast to `Tensor<T>`, but `DynTensor` can be upcast to `DynValue` with `.into_dyn()`, and then downcast to `Tensor<T>` with `.downcast()`.

Downcasts are cheap, as they only check the value's type. Upcasts compile to a no-op.
</Callout>

## Views
A view (also called a ref) is functionally a borrowed variant of a value. There are also mutable views, which are equivalent to mutably borrowed values. Views are represented as separate structs so that they can be down/upcasted.

View types are suffixed with `Ref` or `RefMut` for shared/mutable variants respectively:
- Tensors have `DynTensorRef(Mut)` and `TensorRef(Mut)`.
- Maps have `DynMapRef(Mut)` and `MapRef(Mut)`.
- Sequences have `DynSequenceRef(Mut)` and `SequenceRef(Mut)`.

These views can be acquired with `.view()` or `.view_mut()` on a value type:
```rs
let my_tensor: ort::Tensor<f32> = Tensor::new(...)?;

let tensor_view: ort::TensorRef<'_, f32> = my_tensor.view();
```

Views act identically to a borrow of their type - `TensorRef` supports `extract_tensor`, `TensorRefMut` supports `extract_tensor_mut`. The same is true for sequences & maps. Views also support down/upcasting via `.downcast()` & `.into_dyn()` (but not `.upcast()` at the moment).
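
As a rough mental model of why views are separate structs, the sketch below wraps a shared or mutable borrow in its own type. It is plain Rust with hypothetical names (`ToyTensor`, `ToyTensorRef`, `ToyTensorRefMut`, `extract`, `extract_mut`) and says nothing about how `ort` implements views internally.

```rust
// Toy model only: none of these types are part of the `ort` API.

/// Stand-in for an owned tensor.
struct ToyTensor {
    data: Vec<f32>,
}

/// Shared view: a struct that borrows the tensor's data immutably.
struct ToyTensorRef<'a> {
    data: &'a [f32],
}

/// Mutable view: a struct that borrows the tensor's data mutably.
struct ToyTensorRefMut<'a> {
    data: &'a mut [f32],
}

impl ToyTensor {
    fn view(&self) -> ToyTensorRef<'_> {
        ToyTensorRef { data: &self.data }
    }

    fn view_mut(&mut self) -> ToyTensorRefMut<'_> {
        ToyTensorRefMut { data: &mut self.data }
    }
}

impl ToyTensorRef<'_> {
    /// Read-only access, like extracting from a shared borrow.
    fn extract(&self) -> &[f32] {
        self.data
    }
}

impl ToyTensorRefMut<'_> {
    /// Read-write access, like extracting from a mutable borrow.
    fn extract_mut(&mut self) -> &mut [f32] {
        &mut *self.data
    }
}
```

Because each view is a named type rather than a bare `&ToyTensor`, it can carry its own methods, which is the property that lets the real view types offer their own casting methods.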

You can also directly downcast a value to a stronger-typed view using `.downcast_ref()` and `.downcast_mut()`:
```rs
let tensor_view: ort::TensorRef<'_, f32> = dyn_value.downcast_ref()?;
// is equivalent to
let tensor_view: ort::TensorRef<'_, f32> = dyn_value.view().downcast()?;
```
Binary file added docs/public/assets/casting-map.png