Merge pull request #420 from robertknight/value-dtype
Add data type information to value nodes
robertknight authored Nov 30, 2024
2 parents 2bfd667 + 20ab431 commit 35e119e
Showing 13 changed files with 284 additions and 106 deletions.
30 changes: 22 additions & 8 deletions rten-cli/src/main.rs
@@ -3,7 +3,8 @@ use std::error::Error;
use std::time::Instant;

use rten::{
Dimension, InputOrOutput, Model, ModelMetadata, ModelOptions, NodeId, Output, RunOptions,
DataType, Dimension, InputOrOutput, Model, ModelMetadata, ModelOptions, NodeId, Output,
RunOptions,
};
use rten_tensor::prelude::*;
use rten_tensor::Tensor;
@@ -242,6 +243,7 @@ fn run_with_random_input(
let shape = info
.shape()
.ok_or(format!("Unable to get shape for input {}", name))?;
let dtype = info.dtype();

let resolved_shape: Vec<usize> = shape
.iter()
@@ -264,6 +266,10 @@
})
.collect();

fn random_ints<T, F: FnMut() -> T>(shape: &[usize], gen: F) -> Output where Output: From<Tensor<T>> {
Tensor::from_simple_fn(shape, gen).into()
}

// Guess suitable content for the input based on its name.
let tensor = match name {
// If this is a mask, use all ones on the assumption that we
@@ -283,12 +289,17 @@
// cached outputs from a previous run.
"use_cache_branch" => Output::from(Tensor::from(0i32)),

// For anything else, random floats in [0, 1].
//
// TODO - Value nodes in the model should include data types,
// so we can at least be sure to generate values of the correct
// type.
_ => Output::from(Tensor::from_simple_fn(&resolved_shape, || rng.f32())),
// For anything else, random values.
_ => match dtype {
// Generate floats in [0, 1]
Some(DataType::Float) | None => Output::from(Tensor::from_simple_fn(&resolved_shape, || rng.f32())),
// Generate random values for int types. The default ranges
// are intended to be suitable for many models, but there
// ought to be a way to override them.
Some(DataType::Int32) => random_ints(&resolved_shape, || rng.i32(0..256)),
Some(DataType::Int8) => random_ints(&resolved_shape, || rng.i8(0..=127)),
Some(DataType::UInt8) => random_ints(&resolved_shape, || rng.u8(0..=255)),
}
};

inputs.push((id, tensor));
@@ -393,8 +404,11 @@ fn print_input_output_list(model: &Model, node_ids: &[NodeId]) {
continue;
};
println!(
" {}: {}",
" {}: {} {}",
info.name().unwrap_or("(unnamed)"),
info.dtype()
.map(|dt| dt.to_string())
.unwrap_or("(unknown dtype)".to_string()),
info.shape()
.map(|dims| format_shape(&dims))
.unwrap_or("(unknown shape)".to_string())
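The effect of the CLI change above is that random inputs are now generated to match each input's declared dtype instead of always being f32. As a rough, language-neutral illustration of that dispatch (a NumPy sketch, not rten code; the dtype names and value ranges simply mirror the ones chosen in the diff):

import numpy as np

rng = np.random.default_rng(0)

def random_input(shape, dtype):
    # Illustrative analogue of the dtype dispatch above; `dtype` is a plain
    # string here, whereas the CLI matches on rten's DataType enum.
    if dtype is None or dtype == "float32":
        return rng.random(shape, dtype=np.float32)  # floats in [0, 1)
    if dtype == "int32":
        return rng.integers(0, 256, size=shape, dtype=np.int32)
    if dtype == "int8":
        return rng.integers(0, 128, size=shape, dtype=np.int8)   # 0..=127
    if dtype == "uint8":
        return rng.integers(0, 256, size=shape, dtype=np.uint8)  # 0..=255
    raise ValueError(f"unsupported dtype {dtype!r}")

print(random_input((2, 3), "int8"))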
36 changes: 30 additions & 6 deletions rten-convert/rten_convert/converter.py
@@ -129,10 +129,18 @@ class ValueNode(Node):
export time) sizes.
"""

def __init__(self, name: str, shape: list[int | str] | None):
def __init__(self, name: str, shape: list[int | str] | None, dtype: int | None):
"""
Initialize a value node.
:param name: Unique name of the value
:param shape: Expected shape of tensor at runtime
:param dtype: Expected data type of tensor at runtime. Value from `sg.DataType`.
"""
super().__init__(name)

self.shape = shape
self.dtype = dtype


class Graph:
@@ -503,11 +511,17 @@ def noop_add_node(node: Node) -> int:


def value_node_from_onnx_value(value: onnx.ValueInfoProto) -> ValueNode:
if value.type.tensor_type.HasField("elem_type"):
dtype = convert_data_type(value.type.tensor_type.elem_type)
else:
dtype = None

if value.type.tensor_type.HasField("shape"):
dims = [d.dim_param or d.dim_value for d in value.type.tensor_type.shape.dim]
else:
dims = None
return ValueNode(name=value.name, shape=dims)

return ValueNode(name=value.name, shape=dims, dtype=dtype)


class PadAttrs(Protocol):
@@ -678,7 +692,11 @@ def op_node_from_onnx_operator(

case "Cast":
attrs = sg.CastAttrsT()
to = op_reader.get_attr("to", "int", TensorProto.DataType.FLOAT) # type:ignore[attr-defined]
to = op_reader.get_attr(
"to",
"int",
TensorProto.DataType.FLOAT, # type:ignore[attr-defined]
)
attrs.to = convert_data_type(to)

case "Clip":
@@ -1121,15 +1139,17 @@ def add_value_node(value: ValueInfoProto):
if capture_ids is not None:
for input_name in operator.input:
if input_name not in value_name_to_index:
capture_id = add_node(ValueNode(name=input_name, shape=None))
capture_id = add_node(
ValueNode(name=input_name, shape=None, dtype=None)
)
capture_ids.append(capture_id)

for output_name in operator.output:
# If this output is also a model output, it will have been
# registered already.
if output_name in value_name_to_index:
continue
value_node = ValueNode(output_name, shape=None)
value_node = ValueNode(output_name, shape=None, dtype=None)
add_node(value_node)

try:
@@ -1194,7 +1214,9 @@ def build_constant_node(
inline_data_type = sg.ConstantData.UInt8Data
dtype = sg.ConstantDataType.UInt8
case _:
raise ValueError(f"Unsupported data array type {constant.data.dtype.name}") # type:ignore[union-attr]
raise ValueError(
f"Unsupported data array type {constant.data.dtype.name}" # type:ignore[union-attr]
)

# Store inline if we're generating the V1 format, or the tensor is small.
# Small values are mostly parameters such as axes, slice ranges etc.
@@ -1347,6 +1369,8 @@ def write_dim(builder, dim: str | int) -> int:
sg.ValueNodeStart(builder)
if shape_vec:
sg.ValueNodeAddShape(builder, shape_vec)
if value.dtype is not None:
sg.ValueNodeAddDtype(builder, value.dtype)
return sg.ValueNodeEnd(builder)


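For readers unfamiliar with the ONNX protobuf layout the converter reads from, here is a small self-contained sketch (assuming only the onnx package) of the fields that value_node_from_onnx_value now inspects; the actual mapping from elem_type to an rten type goes through convert_data_type, which is not reproduced here.

from onnx import TensorProto, helper

# A ValueInfoProto like those in an ONNX graph's input/output/value_info lists.
vi = helper.make_tensor_value_info("logits", TensorProto.FLOAT, [1, "seq", 50257])

tt = vi.type.tensor_type
elem_type = tt.elem_type if tt.HasField("elem_type") else None  # 1 == TensorProto.FLOAT
dims = (
    [d.dim_param or d.dim_value for d in tt.shape.dim]
    if tt.HasField("shape")
    else None
)
print(elem_type, dims)  # 1 [1, 'seq', 50257]

When the ONNX model does not declare an element type, the converter leaves dtype as None on the ValueNode, just as it already did for missing shapes.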
15 changes: 14 additions & 1 deletion rten-convert/rten_convert/schema_generated.py
@@ -5921,15 +5921,25 @@ def ShapeIsNone(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
return o == 0

# ValueNode
def Dtype(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
if o != 0:
return self._tab.Get(flatbuffers.number_types.Uint8Flags, o + self._tab.Pos)
return None

def ValueNodeStart(builder):
builder.StartObject(1)
builder.StartObject(2)

def ValueNodeAddShape(builder, shape):
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(shape), 0)

def ValueNodeStartShapeVector(builder, numElems):
return builder.StartVector(4, numElems, 4)

def ValueNodeAddDtype(builder, dtype):
builder.PrependUint8Slot(1, dtype, None)

def ValueNodeEnd(builder):
return builder.EndObject()

@@ -5944,6 +5954,7 @@ class ValueNodeT(object):
# ValueNodeT
def __init__(self):
self.shape = None # type: List[DimT]
self.dtype = None # type: Optional[int]

@classmethod
def InitFromBuf(cls, buf, pos):
@@ -5974,6 +5985,7 @@ def _UnPack(self, valueNode):
else:
dim_ = DimT.InitFromObj(valueNode.Shape(i))
self.shape.append(dim_)
self.dtype = valueNode.Dtype()

# ValueNodeT
def Pack(self, builder):
@@ -5988,6 +6000,7 @@ def Pack(self, builder):
ValueNodeStart(builder)
if self.shape is not None:
ValueNodeAddShape(builder, shape)
ValueNodeAddDtype(builder, self.dtype)
valueNode = ValueNodeEnd(builder)
return valueNode

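The generated code above stores dtype as an optional uint8 in ValueNode's second field slot (vtable offset 6). A minimal round-trip sketch, assuming rten_convert's generated module is importable and using a placeholder integer where real callers would pass a value from sg.DataType:

import flatbuffers
from rten_convert import schema_generated as sg

builder = flatbuffers.Builder(0)
sg.ValueNodeStart(builder)
sg.ValueNodeAddDtype(builder, 1)  # placeholder; real code passes an sg.DataType value
builder.Finish(sg.ValueNodeEnd(builder))
buf = builder.Output()

# Resolve the root table offset (this mirrors what generated GetRootAs helpers do),
# then read the new field back through the generated accessor.
root = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, 0)
node = sg.ValueNode()
node.Init(buf, root)
print(node.Dtype())  # 1; None if the field was never written

Because the field is optional, buffers written before this change still read cleanly: Dtype() simply returns None, which the CLI treats as "assume float" and the converter prints as "(unknown dtype)".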
