diff --git a/rten-cli/src/main.rs b/rten-cli/src/main.rs
index ad32b342..f2eb6be8 100644
--- a/rten-cli/src/main.rs
+++ b/rten-cli/src/main.rs
@@ -3,7 +3,8 @@ use std::error::Error;
 use std::time::Instant;
 
 use rten::{
-    Dimension, InputOrOutput, Model, ModelMetadata, ModelOptions, NodeId, Output, RunOptions,
+    DataType, Dimension, InputOrOutput, Model, ModelMetadata, ModelOptions, NodeId, Output,
+    RunOptions,
 };
 use rten_tensor::prelude::*;
 use rten_tensor::Tensor;
@@ -242,6 +243,7 @@ fn run_with_random_input(
         let shape = info
             .shape()
             .ok_or(format!("Unable to get shape for input {}", name))?;
+        let dtype = info.dtype();
 
         let resolved_shape: Vec<usize> = shape
             .iter()
@@ -264,6 +266,10 @@ fn run_with_random_input(
             })
             .collect();
 
+        fn random_ints<T, F: Fn() -> T>(shape: &[usize], gen: F) -> Output where Output: From<Tensor<T>> {
+            Tensor::from_simple_fn(shape, gen).into()
+        }
+
         // Guess suitable content for the input based on its name.
         let tensor = match name {
             // If this is a mask, use all ones on the assumption that we
@@ -283,12 +289,17 @@ fn run_with_random_input(
             // cached outputs from a previous run.
             "use_cache_branch" => Output::from(Tensor::from(0i32)),
 
-            // For anything else, random floats in [0, 1].
-            //
-            // TODO - Value nodes in the model should include data types,
-            // so we can at least be sure to generate values of the correct
-            // type.
-            _ => Output::from(Tensor::from_simple_fn(&resolved_shape, || rng.f32())),
+            // For anything else, random values.
+            _ => match dtype {
+                // Generate floats in [0, 1]
+                Some(DataType::Float) | None => Output::from(Tensor::from_simple_fn(&resolved_shape, || rng.f32())),
+                // Generate random values for int types. The default ranges
+                // are intended to be suitable for many models, but there
+                // ought to be a way to override them.
+                Some(DataType::Int32) => random_ints(&resolved_shape, || rng.i32(0..256)),
+                Some(DataType::Int8) => random_ints(&resolved_shape, || rng.i8(0..=127)),
+                Some(DataType::UInt8) => random_ints(&resolved_shape, || rng.u8(0..=255)),
+            }
         };
 
         inputs.push((id, tensor));
@@ -393,8 +404,11 @@ fn print_input_output_list(model: &Model, node_ids: &[NodeId]) {
             continue;
         };
         println!(
-            "  {}: {}",
+            "  {}: {} {}",
             info.name().unwrap_or("(unnamed)"),
+            info.dtype()
+                .map(|dt| dt.to_string())
+                .unwrap_or("(unknown dtype)".to_string()),
             info.shape()
                 .map(|dims| format_shape(&dims))
                 .unwrap_or("(unknown shape)".to_string())
diff --git a/rten-convert/rten_convert/converter.py b/rten-convert/rten_convert/converter.py
index 6fa326d0..d286803a 100644
--- a/rten-convert/rten_convert/converter.py
+++ b/rten-convert/rten_convert/converter.py
@@ -129,10 +129,18 @@ class ValueNode(Node):
         export time) sizes.
     """
 
-    def __init__(self, name: str, shape: list[int | str] | None):
+    def __init__(self, name: str, shape: list[int | str] | None, dtype: int | None):
+        """
+        Initialize a value node.
+
+        :param name: Unique name of the value
+        :param shape: Expected shape of tensor at runtime
+        :param dtype: Expected data type of tensor at runtime. Value from `sg.DataType`.
+ """ super().__init__(name) self.shape = shape + self.dtype = dtype class Graph: @@ -503,11 +511,17 @@ def noop_add_node(node: Node) -> int: def value_node_from_onnx_value(value: onnx.ValueInfoProto) -> ValueNode: + if value.type.tensor_type.HasField("elem_type"): + dtype = convert_data_type(value.type.tensor_type.elem_type) + else: + dtype = None + if value.type.tensor_type.HasField("shape"): dims = [d.dim_param or d.dim_value for d in value.type.tensor_type.shape.dim] else: dims = None - return ValueNode(name=value.name, shape=dims) + + return ValueNode(name=value.name, shape=dims, dtype=dtype) class PadAttrs(Protocol): @@ -678,7 +692,11 @@ def op_node_from_onnx_operator( case "Cast": attrs = sg.CastAttrsT() - to = op_reader.get_attr("to", "int", TensorProto.DataType.FLOAT) # type:ignore[attr-defined] + to = op_reader.get_attr( + "to", + "int", + TensorProto.DataType.FLOAT, # type:ignore[attr-defined] + ) attrs.to = convert_data_type(to) case "Clip": @@ -1121,7 +1139,9 @@ def add_value_node(value: ValueInfoProto): if capture_ids is not None: for input_name in operator.input: if input_name not in value_name_to_index: - capture_id = add_node(ValueNode(name=input_name, shape=None)) + capture_id = add_node( + ValueNode(name=input_name, shape=None, dtype=None) + ) capture_ids.append(capture_id) for output_name in operator.output: @@ -1129,7 +1149,7 @@ def add_value_node(value: ValueInfoProto): # registered already. if output_name in value_name_to_index: continue - value_node = ValueNode(output_name, shape=None) + value_node = ValueNode(output_name, shape=None, dtype=None) add_node(value_node) try: @@ -1194,7 +1214,9 @@ def build_constant_node( inline_data_type = sg.ConstantData.UInt8Data dtype = sg.ConstantDataType.UInt8 case _: - raise ValueError(f"Unsupported data array type {constant.data.dtype.name}") # type:ignore[union-attr] + raise ValueError( + f"Unsupported data array type {constant.data.dtype.name}" # type:ignore[union-attr] + ) # Store inline if we're generating the V1 format, or the tensor is small. # Small values are mostly parameters such as axes, slice ranges etc. 
@@ -1347,6 +1369,8 @@ def write_dim(builder, dim: str | int) -> int:
     sg.ValueNodeStart(builder)
     if shape_vec:
         sg.ValueNodeAddShape(builder, shape_vec)
+    if value.dtype is not None:
+        sg.ValueNodeAddDtype(builder, value.dtype)
     return sg.ValueNodeEnd(builder)
 
 
diff --git a/rten-convert/rten_convert/schema_generated.py b/rten-convert/rten_convert/schema_generated.py
index bf865539..19d3c1ae 100644
--- a/rten-convert/rten_convert/schema_generated.py
+++ b/rten-convert/rten_convert/schema_generated.py
@@ -5921,8 +5921,15 @@ def ShapeIsNone(self):
         o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
         return o == 0
 
+    # ValueNode
+    def Dtype(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
+        if o != 0:
+            return self._tab.Get(flatbuffers.number_types.Uint8Flags, o + self._tab.Pos)
+        return None
+
 def ValueNodeStart(builder):
-    builder.StartObject(1)
+    builder.StartObject(2)
 
 def ValueNodeAddShape(builder, shape):
     builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(shape), 0)
@@ -5930,6 +5937,9 @@ def ValueNodeAddShape(builder, shape):
 def ValueNodeStartShapeVector(builder, numElems):
     return builder.StartVector(4, numElems, 4)
 
+def ValueNodeAddDtype(builder, dtype):
+    builder.PrependUint8Slot(1, dtype, None)
+
 def ValueNodeEnd(builder):
     return builder.EndObject()
 
@@ -5944,6 +5954,7 @@ class ValueNodeT(object):
     # ValueNodeT
     def __init__(self):
         self.shape = None  # type: List[DimT]
+        self.dtype = None  # type: Optional[int]
 
     @classmethod
     def InitFromBuf(cls, buf, pos):
@@ -5974,6 +5985,7 @@ def _UnPack(self, valueNode):
             else:
                 dim_ = DimT.InitFromObj(valueNode.Shape(i))
             self.shape.append(dim_)
+        self.dtype = valueNode.Dtype()
 
     # ValueNodeT
     def Pack(self, builder):
@@ -5988,6 +6000,7 @@ def Pack(self, builder):
         ValueNodeStart(builder)
         if self.shape is not None:
             ValueNodeAddShape(builder, shape)
+        ValueNodeAddDtype(builder, self.dtype)
         valueNode = ValueNodeEnd(builder)
         return valueNode
 
diff --git a/src/graph.rs b/src/graph.rs
index 7cf2fafa..33aa77ff 100644
--- a/src/graph.rs
+++ b/src/graph.rs
@@ -17,7 +17,9 @@ use smallvec::SmallVec;
 
 use crate::constant_storage::ArcTensorView;
 use crate::env::env_flag;
-use crate::ops::{Input, InputList, InputOrOutput, OpError, Operator, Output, OutputList};
+use crate::ops::{
+    DataType, Input, InputList, InputOrOutput, OpError, Operator, Output, OutputList,
+};
 use crate::tensor_pool::TensorPool;
 use crate::threading;
 use crate::timing::{InputShape, Instant, RunTiming, TimingRecord, TimingSort};
@@ -86,6 +88,7 @@ impl OperatorNode {
 pub struct ValueNode {
     name: Option<String>,
     shape: Option<Vec<Dimension>>,
+    dtype: Option<DataType>,
 }
 
 impl ValueNode {
@@ -196,6 +199,15 @@ impl Constant {
             Constant::UInt8(i) => Input::UInt8Tensor(i.view()),
         }
     }
+
+    fn dtype(&self) -> DataType {
+        match self {
+            Constant::Float(_) => DataType::Float,
+            Constant::Int32(_) => DataType::Int32,
+            Constant::Int8(_) => DataType::Int8,
+            Constant::UInt8(_) => DataType::UInt8,
+        }
+    }
 }
 
 macro_rules! impl_constant_node {
@@ -280,6 +292,20 @@ impl Node {
             Node::Value(node) => node.shape.clone(),
         }
     }
+
+    /// Return the data type associated with this node.
+    ///
+    /// - For constants this returns the element type of the tensor
+    /// - For values this returns the expected element type of the tensor at
+    ///   runtime, if known
+    /// - For operators this always returns `None`.
+    pub fn dtype(&self) -> Option<DataType> {
+        match self {
+            Node::Value(node) => node.dtype,
+            Node::Constant(constant) => Some(constant.dtype()),
+            Node::Operator(_) => None,
+        }
+    }
 }
 
 /// ID of a node in a [`Model`](crate::Model) graph.
@@ -840,7 +866,7 @@ impl Graph {
         input_ids: &[NodeId],
     ) -> (NodeId, NodeId) {
         let op_out_name = format!("{}_out", name);
-        let op_out_id = self.add_value(Some(&op_out_name), None);
+        let op_out_id = self.add_value(Some(&op_out_name), None, None);
         let input_ids: Vec<_> = input_ids.iter().copied().map(Some).collect();
         let op_node_id = self.add_op(Some(name), Box::new(op), &input_ids, &[op_out_id].map(Some));
         (op_node_id, op_out_id)
     }
@@ -882,10 +908,16 @@ impl Graph {
     /// the graph is executed, such as an input or operator output.
     ///
     /// Returns the ID of the added node.
-    pub fn add_value(&mut self, name: Option<&str>, shape: Option<Vec<Dimension>>) -> NodeId {
+    pub fn add_value(
+        &mut self,
+        name: Option<&str>,
+        shape: Option<Vec<Dimension>>,
+        dtype: Option<DataType>,
+    ) -> NodeId {
         self.add_node(Node::Value(ValueNode {
             name: name.map(|s| s.to_owned()),
             shape,
+            dtype,
         }))
     }
 
@@ -1788,8 +1820,8 @@ mod tests {
     use super::{CachedPlan, CaptureEnv};
     use crate::graph::{Dimension, Graph, Node, NodeId, RunError, RunOptions, TypedConstant};
     use crate::ops::{
-        Add, Concat, Conv, Identity, If, InputList, IntoOpResult, Mul, OpError, Operator, Output,
-        OutputList, Relu, Shape,
+        Add, Concat, Conv, DataType, Identity, If, InputList, IntoOpResult, Mul, OpError, Operator,
+        Output, OutputList, Relu, Shape,
     };
     use crate::tensor_pool::TensorPool;
 
@@ -1870,7 +1902,7 @@ mod tests {
             ],
         );
         let weights_id = g.add_constant(Some("weight"), weights);
-        let input_id = g.add_value(Some("input"), None);
+        let input_id = g.add_value(Some("input"), None, None);
 
         let (_, conv_out) = g.add_simple_op(
             "conv",
@@ -1918,8 +1950,8 @@ mod tests {
         let weights = Tensor::from([0.3230]);
         let weights_id = g.add_constant(Some("weights"), weights.clone());
 
-        let input_id = g.add_value(Some("input"), None);
-        let relu_out_id = g.add_value(Some("relu_out"), None);
+        let input_id = g.add_value(Some("input"), None, None);
+        let relu_out_id = g.add_value(Some("relu_out"), None, None);
         let relu_op_id = g.add_op(
             Some("relu"),
             Box::new(Relu {}),
@@ -1932,8 +1964,8 @@ mod tests {
         assert_eq!(g.node_name(relu_op_id), "relu");
 
         let anon_weights_id = g.add_constant(None, weights);
-        let anon_input_id = g.add_value(None, None);
-        let anon_out_id = g.add_value(None, None);
+        let anon_input_id = g.add_value(None, None, None);
+        let anon_out_id = g.add_value(None, None, None);
         let anon_op_id = g.add_op(
             None,
             Box::new(Relu {}),
@@ -1969,6 +2001,7 @@ mod tests {
                 ]
                 .to_vec(),
             ),
+            None,
         );
 
         let (relu_op_id, _) = g.add_simple_op("relu", Relu {}, &[input_id]);
@@ -1991,6 +2024,21 @@ mod tests {
         assert_eq!(g.get_node(relu_op_id).and_then(|n| n.shape()), None);
     }
 
+    #[test]
+    fn test_graph_value_dtype() {
+        let mut g = Graph::new();
+        for dtype in [
+            DataType::Float,
+            DataType::Int32,
+            DataType::UInt8,
+            DataType::Int8,
+        ] {
+            let input_id = g.add_value(None, None, Some(dtype));
+            let input_dtype = g.get_node(input_id).and_then(|n| n.dtype());
+            assert_eq!(input_dtype, Some(dtype));
+        }
+    }
+
     #[derive(Debug)]
     struct AddOne {}
     impl Operator for AddOne {
@@ -2009,7 +2057,7 @@ mod tests {
     fn test_graph_planning_order() -> Result<(), Box<dyn Error>> {
         let mut g = Graph::new();
 
-        let input_id = g.add_value(Some("input"), None);
+        let input_id = g.add_value(Some("input"), None, None);
         let (_, op_a_out) = g.add_simple_op("op_a", AddOne {}, &[input_id]);
         let (_, op_b_out) =
g.add_simple_op("op_b", AddOne {}, &[op_a_out]); @@ -2042,8 +2090,8 @@ mod tests { fn test_runs_non_in_place_ops_first() -> Result<(), Box> { let mut g = Graph::new(); - let input_a_id = g.add_value(Some("input_a"), None); - let input_b_id = g.add_value(Some("input_b"), None); + let input_a_id = g.add_value(Some("input_a"), None, None); + let input_b_id = g.add_value(Some("input_b"), None, None); let (add_op, add_out) = g.add_simple_op("add", Add {}, &[input_a_id, input_b_id]); let (shape_op, shape_out) = g.add_simple_op("shape", Shape {}, &[input_a_id]); @@ -2069,7 +2117,7 @@ mod tests { fn test_graph_intermediate_output() { let mut g = Graph::new(); - let input_id = g.add_value(Some("input"), None); + let input_id = g.add_value(Some("input"), None, None); let (_, op_a_out) = g.add_simple_op("op_a", AddOne {}, &[input_id]); let (_, op_b_out) = g.add_simple_op("op_b", AddOne {}, &[op_a_out]); @@ -2092,11 +2140,11 @@ mod tests { let mut g = Graph::new(); let input = Tensor::from([1., 2., 3., 4., 5.]); - let input_id = g.add_value(Some("input"), None); + let input_id = g.add_value(Some("input"), None, None); let mut prev_output = input_id; for _ in 0..100 { - let next_output = g.add_value(None, None); + let next_output = g.add_value(None, None, None); g.add_op( None, Box::new(AddOne {}), @@ -2121,7 +2169,7 @@ mod tests { let mut g = Graph::new(); let input = Tensor::from([1., 2., 3., 4., 5.]); - let input_id = g.add_value(Some("input"), None); + let input_id = g.add_value(Some("input"), None, None); let results = g .run(vec![(input_id, input.view().into())], &[input_id], None) @@ -2194,7 +2242,7 @@ mod tests { #[test] fn test_duplicate_inputs() { let mut g = Graph::new(); - let input_id = g.add_value(Some("input"), None); + let input_id = g.add_value(Some("input"), None, None); let input = Tensor::from([1.]); let result = g.run( vec![ @@ -2214,7 +2262,7 @@ mod tests { fn test_duplicate_outputs() { let mut g = Graph::new(); - let input_id = g.add_value(Some("input"), None); + let input_id = g.add_value(Some("input"), None, None); let (_, op_a_out) = g.add_simple_op("op_a", AddOne {}, &[input_id]); let input = Tensor::from([1.]); @@ -2234,7 +2282,7 @@ mod tests { // Call an operator with an input omitted by setting it to `None`, // as opposed to passing a shorter input list. This enables omitting // an input but still providing subsequent ones. 
- let output = g.add_value(None, None); + let output = g.add_value(None, None, None); g.add_op(Some("shape"), Box::new(Shape {}), &[None], &[Some(output)]); let results = g.run(vec![], &[output], None); @@ -2307,7 +2355,7 @@ mod tests { #[test] fn test_runs_op_in_place() { let mut g = Graph::new(); - let input_id = g.add_value(Some("input"), None); + let input_id = g.add_value(Some("input"), None, None); let (_, op1_out) = g.add_simple_op("op1", AddOneInPlace {}, &[input_id]); let (_, op2_out) = g.add_simple_op("op2", AddOneInPlace {}, &[op1_out]); @@ -2350,8 +2398,8 @@ mod tests { use crate::ops::Add; // A commutative operator let mut g = Graph::new(); - let input_id = g.add_value(Some("input"), None); - let bias_id = g.add_value(Some("bias"), None); + let input_id = g.add_value(Some("input"), None, None); + let bias_id = g.add_value(Some("bias"), None, None); let op1 = TrackUsage::new(Add {}); let op1_metrics = op1.metrics(); @@ -2438,9 +2486,9 @@ mod tests { #[test] fn test_multiple_outputs() { let mut g = Graph::new(); - let input_id = g.add_value(Some("input"), None); - let left_split_out = g.add_value(Some("left_split"), None); - let right_split_out = g.add_value(Some("right_split"), None); + let input_id = g.add_value(Some("input"), None, None); + let left_split_out = g.add_value(Some("left_split"), None, None); + let right_split_out = g.add_value(Some("right_split"), None, None); let split_op = Box::new(Split::new()); let run_count = split_op.run_count.clone(); @@ -2481,9 +2529,9 @@ mod tests { // operators. let mut g = Graph::new(); let const_0 = g.add_constant(Some("c0"), Tensor::from(3.)); - let val_0 = g.add_value(Some("i0"), None); + let val_0 = g.add_value(Some("i0"), None, None); let const_1 = g.add_constant(Some("c1"), Tensor::from(4.)); - let val_1 = g.add_value(Some("i1"), None); + let val_1 = g.add_value(Some("i1"), None, None); let (_, op_0_out) = g.add_simple_op("Add_0", Add {}, &[const_0, val_0]); let (_, op_1_out) = g.add_simple_op("Add_1", Add {}, &[const_1, val_1]); @@ -2655,25 +2703,25 @@ mod tests { #[test] fn test_subgraph() { let mut g = Graph::new(); - let input = g.add_value(Some("input"), None); + let input = g.add_value(Some("input"), None, None); // Add subgraphs for `If` operation. These capture `input`. let mut then_branch = Graph::new(); - let tb_input = then_branch.add_value(Some("input"), None); + let tb_input = then_branch.add_value(Some("input"), None, None); let two = then_branch.add_constant(None, Tensor::from(2.)); let (_, tb_output) = then_branch.add_simple_op("Mul", Mul {}, &[tb_input, two]); then_branch.set_captures(&[tb_input]); then_branch.set_output_ids(&[tb_output]); let mut else_branch = Graph::new(); - let eb_input = else_branch.add_value(Some("input"), None); + let eb_input = else_branch.add_value(Some("input"), None, None); let three = else_branch.add_constant(None, Tensor::from(3.)); let (_, eb_output) = else_branch.add_simple_op("Mul", Mul {}, &[eb_input, three]); else_branch.set_captures(&[eb_input]); else_branch.set_output_ids(&[eb_output]); // Add `If` operator that runs one of two subgraphs. 
- let cond = g.add_value(Some("cond"), None); + let cond = g.add_value(Some("cond"), None, None); let branch = If { then_branch, else_branch, @@ -2712,12 +2760,12 @@ mod tests { #[test] fn test_nested_subgraph() { let mut g = Graph::new(); - let input = g.add_value(Some("input"), None); + let input = g.add_value(Some("input"), None, None); let mut subgraph = Graph::new(); let mut nested_subgraph = Graph::new(); - let ns_input = nested_subgraph.add_value(Some("input"), None); + let ns_input = nested_subgraph.add_value(Some("input"), None, None); nested_subgraph.set_captures(&[ns_input]); nested_subgraph.set_output_ids(&[ns_input]); @@ -2742,7 +2790,7 @@ mod tests { #[test] fn test_captures_not_available_when_subgraph_is_run_directly() { let mut subgraph = Graph::new(); - let sg_input = subgraph.add_value(Some("input"), None); + let sg_input = subgraph.add_value(Some("input"), None, None); subgraph.set_captures(&[sg_input]); let (_, sg_add) = subgraph.add_simple_op("Id", Identity {}, &[sg_input]); subgraph.set_output_ids(&[sg_add]); @@ -2769,10 +2817,10 @@ mod tests { #[test] fn test_partial_run_considers_subgraph_captures() { let mut g = Graph::new(); - let input_id = g.add_value(Some("input"), None); + let input_id = g.add_value(Some("input"), None, None); let mut subgraph = Graph::new(); - let sg_input = subgraph.add_value(Some("input"), None); + let sg_input = subgraph.add_value(Some("input"), None, None); subgraph.set_captures(&[sg_input]); let (_, sg_add) = subgraph.add_simple_op("Id", Identity {}, &[sg_input]); subgraph.set_output_ids(&[sg_add]); @@ -2795,14 +2843,14 @@ mod tests { #[test] fn test_plan_considers_capture_dependencies() { let mut g = Graph::new(); - let input_id = g.add_value(Some("input"), None); + let input_id = g.add_value(Some("input"), None, None); let (_, _) = g.add_simple_op("Add", Add {}, &[input_id, input_id]); // Add a subgraph with a captured value that is the output of an // operation in the parent graph. let mut subgraph = Graph::new(); - let sg_input = subgraph.add_value(Some("Add_out"), None); + let sg_input = subgraph.add_value(Some("Add_out"), None, None); subgraph.set_captures(&[sg_input]); let (_, sg_out) = subgraph.add_simple_op("Id", Identity {}, &[sg_input]); subgraph.set_output_ids(&[sg_out]); @@ -2820,7 +2868,7 @@ mod tests { #[test] fn test_plan_considers_transitive_capture_dependencies() { let mut g = Graph::new(); - let input_id = g.add_value(Some("input"), None); + let input_id = g.add_value(Some("input"), None, None); let (_, _) = g.add_simple_op("Add", Add {}, &[input_id, input_id]); @@ -2828,7 +2876,7 @@ mod tests { // a dependency on an operator output in the top-level graph. let mut subgraph = Graph::new(); let mut nested_subgraph = Graph::new(); - let ns_input = nested_subgraph.add_value(Some("Add_out"), None); + let ns_input = nested_subgraph.add_value(Some("Add_out"), None, None); nested_subgraph.set_captures(&[ns_input]); let (_, ns_out) = nested_subgraph.add_simple_op("Id", Identity {}, &[ns_input]); nested_subgraph.set_output_ids(&[ns_out]); @@ -2855,7 +2903,7 @@ mod tests { #[test] fn test_keeps_temp_value_needed_as_subgraph_capture() { let mut g = Graph::new(); - let input_id = g.add_value(Some("input"), None); + let input_id = g.add_value(Some("input"), None, None); // Compute a temporary `id_out` value and use it in the main graph. let (_, id_out) = g.add_simple_op("Id", Identity {}, &[input_id]); @@ -2865,7 +2913,7 @@ mod tests { // capture. 
Graph execution must keep the `id_out` value around until // this has run, even though no ops in the main graph need it as inputs. let mut subgraph = Graph::new(); - let sg_input = subgraph.add_value(Some("Id_out"), None); + let sg_input = subgraph.add_value(Some("Id_out"), None, None); subgraph.set_captures(&[sg_input]); let (_, sg_out) = subgraph.add_simple_op("Id", Identity {}, &[sg_input]); subgraph.set_output_ids(&[sg_out]); @@ -2885,10 +2933,10 @@ mod tests { // Set up a graph that runs a subgraph and passes captures by value, // if the value is passed to the graph as an owned value. let mut g = Graph::new(); - let input_id = g.add_value(Some("input"), None); + let input_id = g.add_value(Some("input"), None, None); let mut subgraph = Graph::new(); - let sg_input = subgraph.add_value(Some("input"), None); + let sg_input = subgraph.add_value(Some("input"), None, None); subgraph.set_captures(&[sg_input]); let id_op = TrackUsage::new(Identity {}); diff --git a/src/lib.rs b/src/lib.rs index 10817d3c..43fc04af 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -123,7 +123,7 @@ pub use graph::{Dimension, NodeId, RunError, RunOptions}; pub use model::{Model, ModelLoadError, ModelOptions, NodeInfo}; pub use model_metadata::ModelMetadata; pub use op_registry::{OpRegistry, ReadOp, ReadOpError}; -pub use ops::{FloatOperators, Input, InputOrOutput, Operators, Output}; +pub use ops::{DataType, FloatOperators, Input, InputOrOutput, Operators, Output}; pub use tensor_pool::{ExtractBuffer, PoolRef, TensorPool}; pub use threading::{thread_pool, ThreadPool}; pub use timing::TimingSort; diff --git a/src/model.rs b/src/model.rs index 30b49e3c..8a33062a 100644 --- a/src/model.rs +++ b/src/model.rs @@ -21,8 +21,8 @@ use crate::graph::{ use crate::header::{Header, HeaderError}; use crate::model_metadata::ModelMetadata; use crate::number::{LeBytes, Pod}; -use crate::op_registry::{OpLoadContext, OpRegistry, ReadOpError}; -use crate::ops::{InputOrOutput, Output}; +use crate::op_registry::{convert_dtype, OpLoadContext, OpRegistry, ReadOpError}; +use crate::ops::{DataType, InputOrOutput, Output}; use crate::optimize::GraphOptimizer; use crate::schema_generated as sg; use crate::schema_generated::root_as_model; @@ -121,6 +121,14 @@ impl NodeInfo<'_> { pub fn shape(&self) -> Option> { self.node.shape() } + + /// Return the expected data type for this node at runtime. + /// + /// For constants the data type is always known. For values the data type + /// may be specified. For operators this always returns `None`. 
+ pub fn dtype(&self) -> Option { + self.node.dtype() + } } /// Parse profiling flags from the `RTEN_TIMING` environment variable and @@ -516,7 +524,12 @@ impl Model { }) .collect() }); - let graph_node = graph.add_value(name, shape); + let dtype = value + .dtype() + .map(convert_dtype) + .transpose() + .map_err(ModelLoadError::OperatorInvalid)?; + let graph_node = graph.add_value(name, shape, dtype); Ok(graph_node) } @@ -846,7 +859,7 @@ mod tests { }; use crate::ops; use crate::ops::{ - BoxOrder, CoordTransformMode, NearestMode, OpError, Output, ResizeMode, Scalar, + BoxOrder, CoordTransformMode, DataType, NearestMode, OpError, Output, ResizeMode, Scalar, }; use crate::{ModelLoadError, OpRegistry, ReadOpError}; @@ -863,13 +876,14 @@ mod tests { .copied() .map(Dimension::Fixed) .collect(); - let input_node = graph_builder.add_value("input", Some(&input_shape)); - let output_node = graph_builder.add_value("output", None); + let input_node = + graph_builder.add_value("input", Some(&input_shape), Some(DataType::Float)); + let output_node = graph_builder.add_value("output", None, Some(DataType::Float)); graph_builder.add_input(input_node); graph_builder.add_output(output_node); - let concat_out = graph_builder.add_value("concat_out", None); + let concat_out = graph_builder.add_value("concat_out", None, None); graph_builder.add_operator( "concat", OpType::Concat(ops::Concat { axis: 0 }), @@ -959,6 +973,19 @@ mod tests { assert_eq!(shape, &[1, 2, 2].map(Dimension::Fixed)); } + #[test] + fn test_value_dtype_info() { + let buffer = generate_model_buffer(ModelFormat::V2); + let model = Model::load(buffer).unwrap(); + let input_id = model.input_ids()[0]; + + let dtype = model + .node_info(input_id) + .and_then(|ni| ni.dtype()) + .expect("input dtype missing"); + assert_eq!(dtype, DataType::Float); + } + #[test] fn test_metadata() { let buffer = generate_model_buffer(ModelFormat::V2); @@ -1132,7 +1159,7 @@ mod tests { let mut builder = ModelBuilder::new(ModelFormat::V2); let mut graph_builder = builder.graph_builder(); - let output_node = graph_builder.add_value("output", None); + let output_node = graph_builder.add_value("output", None, None); graph_builder.add_output(output_node); graph_builder.add_operator("shape", OpType::Shape, &[None], &[output_node]); @@ -1166,9 +1193,9 @@ mod tests { let mut builder = ModelBuilder::new(ModelFormat::V2); let mut graph_builder = builder.graph_builder(); - let input_node = graph_builder.add_value("input", None); - let input_2d = graph_builder.add_value("input.2d", None); - let input_bool = graph_builder.add_value("input.bool", None); + let input_node = graph_builder.add_value("input", None, None); + let input_2d = graph_builder.add_value("input.2d", None, None); + let input_bool = graph_builder.add_value("input.bool", None, None); // 4D shape used as the primary input to test most operators (eg. NCHW image). A few // require a different shape. 
@@ -1183,7 +1210,7 @@ mod tests { let mut add_operator = |builder: &mut GraphBuilder, name: &str, op: OpType, input_nodes: &[Option]| { let output_name = format!("{}_out", name); - let op_output_node = builder.add_value(&output_name, None); + let op_output_node = builder.add_value(&output_name, None, None); builder.add_operator(name, op, input_nodes, &[op_output_node]); op_outputs.push(output_name); op_output_node @@ -1447,9 +1474,9 @@ mod tests { }); } - let range_start_node = graph_builder.add_value("range_start", None); - let range_limit_node = graph_builder.add_value("range_limit", None); - let range_delta_node = graph_builder.add_value("range_delta", None); + let range_start_node = graph_builder.add_value("range_start", None, None); + let range_limit_node = graph_builder.add_value("range_limit", None, None); + let range_delta_node = graph_builder.add_value("range_delta", None, None); let range_out = add_operator!( Range, [range_start_node, range_limit_node, range_delta_node] @@ -1525,8 +1552,8 @@ mod tests { add_operator!(Squeeze, [input_node]); let split_splits = graph_builder.add_constant(Tensor::from([1, 2]).view()); - let split_out_1 = graph_builder.add_value("Split_out_1", None); - let split_out_2 = graph_builder.add_value("Split_out_2", None); + let split_out_1 = graph_builder.add_value("Split_out_1", None, None); + let split_out_2 = graph_builder.add_value("Split_out_2", None, None); graph_builder.add_operator( "Split", OpType::Split(ops::Split { axis: 1 }), @@ -1543,8 +1570,8 @@ mod tests { add_operator!(Tile, [input_node, tile_repeats]); let topk_k = graph_builder.add_constant(Tensor::from(3).view()); - let topk_out_values = graph_builder.add_value("TopK_out_values", None); - let topk_out_indices = graph_builder.add_value("TopK_out_indices", None); + let topk_out_values = graph_builder.add_value("TopK_out_values", None, None); + let topk_out_indices = graph_builder.add_value("TopK_out_indices", None, None); graph_builder.add_operator( "TopK", OpType::TopK(ops::TopK { @@ -1563,9 +1590,9 @@ mod tests { let unsqueeze_axes = graph_builder.add_constant(Tensor::from([0, 4]).view()); add_operator!(Unsqueeze, [input_node, unsqueeze_axes]); - let where_cond = graph_builder.add_value("where_cond", None); - let where_x = graph_builder.add_value("where_x", None); - let where_y = graph_builder.add_value("where_y", None); + let where_cond = graph_builder.add_value("where_cond", None, None); + let where_x = graph_builder.add_value("where_x", None, None); + let where_y = graph_builder.add_value("where_y", None, None); let where_out = add_operator!(Where, [where_cond, where_x, where_y]); add_operator!(Xor, [input_bool, input_bool]); diff --git a/src/model_builder.rs b/src/model_builder.rs index 4856e584..f315be47 100644 --- a/src/model_builder.rs +++ b/src/model_builder.rs @@ -223,6 +223,15 @@ fn pad_args_from_padding(padding: Padding) -> PadArgs { } } +fn convert_dtype(dtype: DataType) -> sg::DataType { + match dtype { + DataType::Int32 => sg::DataType::Int32, + DataType::Float => sg::DataType::Float, + DataType::Int8 => sg::DataType::Int8, + DataType::UInt8 => sg::DataType::UInt8, + } +} + /// Builder for serializing a graph or subgraph to FlatBuffers. 
pub struct GraphBuilder<'mb, 'a> { builder: &'mb mut FlatBufferBuilder<'a>, @@ -314,7 +323,12 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> { } /// Add a value node to the model - pub fn add_value(&mut self, id: &str, shape: Option<&[Dimension]>) -> NodeId { + pub fn add_value( + &mut self, + id: &str, + shape: Option<&[Dimension]>, + dtype: Option, + ) -> NodeId { let shape = shape.map(|shape| { let dim_vec: Vec<_> = shape .iter() @@ -340,7 +354,8 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> { .collect(); self.builder.create_vector(&dim_vec[..]) }); - let value_node = sg::ValueNode::create(self.builder, &sg::ValueNodeArgs { shape }); + let dtype = dtype.map(convert_dtype); + let value_node = sg::ValueNode::create(self.builder, &sg::ValueNodeArgs { shape, dtype }); self.add_node(Some(id), NodeData::Value(value_node)) } @@ -429,12 +444,7 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> { Cast, CastAttrs, sg::CastAttrsArgs { - to: match args.to { - DataType::Int32 => sg::DataType::Int32, - DataType::Float => sg::DataType::Float, - DataType::Int8 => sg::DataType::Int8, - DataType::UInt8 => sg::DataType::UInt8, - }, + to: convert_dtype(args.to), } ), OpType::Ceil => op!(Ceil), diff --git a/src/op_registry.rs b/src/op_registry.rs index ea12cc28..12441cde 100644 --- a/src/op_registry.rs +++ b/src/op_registry.rs @@ -226,7 +226,7 @@ impl Display for ReadOpError { impl Error for ReadOpError {} -fn convert_dtype(dtype: sg::DataType) -> Result { +pub fn convert_dtype(dtype: sg::DataType) -> Result { match dtype { sg::DataType::Int32 => Ok(DataType::Int32), sg::DataType::Float => Ok(DataType::Float), diff --git a/src/ops/mod.rs b/src/ops/mod.rs index 7a94a6aa..52be7428 100644 --- a/src/ops/mod.rs +++ b/src/ops/mod.rs @@ -170,7 +170,8 @@ impl> From for Padding { } } -#[derive(Copy, Clone, Debug)] +/// Enum specifying the data type of a tensor. +#[derive(Copy, Clone, Debug, PartialEq)] pub enum DataType { Int32, Float, @@ -178,6 +179,23 @@ pub enum DataType { UInt8, } +impl std::fmt::Display for DataType { + /// Format this enum value in the style of the corresponding Rust type (eg. + /// "i32" for `DataType::Int32`). + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + DataType::Float => "f32", + DataType::Int32 => "i32", + DataType::Int8 => "i8", + DataType::UInt8 => "u8", + } + ) + } +} + /// Generate the body of a [`Layout`] impl for a type which wraps an /// underlying layout. macro_rules! impl_proxy_layout { diff --git a/src/optimize.rs b/src/optimize.rs index 47b6079d..de5f506c 100644 --- a/src/optimize.rs +++ b/src/optimize.rs @@ -96,7 +96,7 @@ impl GraphMutator { inputs: &[Option], op_output_id: Option, ) -> NodeId { - let op_output_id = op_output_id.unwrap_or(self.graph.add_value(None, None)); + let op_output_id = op_output_id.unwrap_or(self.graph.add_value(None, None, None)); let op_id = self.graph.add_op(name, op, inputs, &[Some(op_output_id)]); for input_id in inputs.iter().filter_map(|id| *id) { @@ -622,7 +622,7 @@ mod tests { // Capture the constant in the subgraph as a value. let mut subgraph = Graph::new(); - let sg_val = subgraph.add_value(Some("const_a"), None); + let sg_val = subgraph.add_value(Some("const_a"), None, None); subgraph.set_captures(&[sg_val]); subgraph.set_output_ids(&[sg_val]); @@ -652,7 +652,7 @@ mod tests { let (_, add_out) = graph.add_simple_op("add_1", Add {}, &[const_a, const_b]); // Add an operator with a dynamic input and the output of the previous operator. 
- let input = graph.add_value(Some("input"), None); + let input = graph.add_value(Some("input"), None, None); let (add_op_2, add_2_out) = graph.add_simple_op("add_2", Add {}, &[add_out, input]); graph.set_input_ids(&[input]); graph.set_output_ids(&[add_out, add_2_out]); @@ -703,8 +703,8 @@ mod tests { fn test_fuse_transpose() { let mut graph = Graph::new(); - let input_1 = graph.add_value(None, None); - let input_2 = graph.add_value(None, None); + let input_1 = graph.add_value(None, None, None); + let input_2 = graph.add_value(None, None, None); let (_, transpose_out) = graph.add_simple_op("transpose", Transpose { perm: None }, &[input_1]); @@ -723,7 +723,7 @@ mod tests { fn test_fuse_silu() { let mut graph = Graph::new(); - let input = graph.add_value(None, None); + let input = graph.add_value(None, None, None); let (_, sigmoid_out) = graph.add_simple_op("sigmoid", Sigmoid {}, &[input]); let (_, mul_out) = graph.add_simple_op("mul", Mul {}, &[input, sigmoid_out]); graph.set_input_ids(&[input]); @@ -741,7 +741,7 @@ mod tests { let mut graph = Graph::new(); // Add two consecutive decomposed Silu operations - let input = graph.add_value(None, None); + let input = graph.add_value(None, None, None); let (_, sigmoid_out) = graph.add_simple_op("sigmoid", Sigmoid {}, &[input]); let (_, mul_out) = graph.add_simple_op("mul", Mul {}, &[input, sigmoid_out]); let (_, sigmoid_2_out) = graph.add_simple_op("sigmoid", Sigmoid {}, &[mul_out]); @@ -769,7 +769,7 @@ mod tests { let one = graph.add_constant(None, Tensor::from(1.0)); let half = graph.add_constant(None, Tensor::from(0.5)); - let input = graph.add_value(None, None); + let input = graph.add_value(None, None, None); let (_, div_out) = graph.add_simple_op("div", Div {}, &[input, sqrt_2]); let (_, erf_out) = graph.add_simple_op("erf", Erf {}, &[div_out]); let (_, add_out) = graph.add_simple_op("add", Add {}, &[erf_out, one]); @@ -786,7 +786,7 @@ mod tests { fn layer_norm_graph() -> Graph { let mut graph = Graph::new(); - let input = graph.add_value(None, None); + let input = graph.add_value(None, None, None); // Center values let (_, mean_out) = graph.add_simple_op( @@ -843,8 +843,8 @@ mod tests { fn test_optimize_preserves_input_output_nodes() { let mut graph = Graph::new(); - let input_1 = graph.add_value(None, None); - let input_2 = graph.add_value(None, None); + let input_1 = graph.add_value(None, None, None); + let input_2 = graph.add_value(None, None, None); // Add fuse-able Transpose + MatMul let (_, transpose_out) = diff --git a/src/optimize/pattern_matcher.rs b/src/optimize/pattern_matcher.rs index f20b414d..af874298 100644 --- a/src/optimize/pattern_matcher.rs +++ b/src/optimize/pattern_matcher.rs @@ -350,7 +350,7 @@ mod tests { /// Create a graph that implements the softsign function `x / 1 + |x|`. fn softsign_graph() -> (Graph, NodeId, NodeId) { let mut graph = Graph::new(); - let input_id = graph.add_value(Some("x"), None); + let input_id = graph.add_value(Some("x"), None, None); let (_, abs_out) = graph.add_simple_op("abs", Abs {}, &[input_id]); let one = graph.add_constant(None, Tensor::from(1.0)); diff --git a/src/schema.fbs b/src/schema.fbs index 1b658644..0b7b3b42 100644 --- a/src/schema.fbs +++ b/src/schema.fbs @@ -550,6 +550,8 @@ table Dim { table ValueNode { // Expected shape of the tensor at runtime. shape:[Dim]; + // Expected data type of the tensor at runtime. 
+ dtype:DataType = null; } table Node { diff --git a/src/schema_generated.rs b/src/schema_generated.rs index 5f74b3cd..8247ea4b 100644 --- a/src/schema_generated.rs +++ b/src/schema_generated.rs @@ -9987,6 +9987,7 @@ impl<'a> flatbuffers::Follow<'a> for ValueNode<'a> { impl<'a> ValueNode<'a> { pub const VT_SHAPE: flatbuffers::VOffsetT = 4; + pub const VT_DTYPE: flatbuffers::VOffsetT = 6; #[inline] pub unsafe fn init_from_table(table: flatbuffers::Table<'a>) -> Self { @@ -10001,6 +10002,9 @@ impl<'a> ValueNode<'a> { if let Some(x) = args.shape { builder.add_shape(x); } + if let Some(x) = args.dtype { + builder.add_dtype(x); + } builder.finish() } @@ -10015,6 +10019,13 @@ impl<'a> ValueNode<'a> { >>(ValueNode::VT_SHAPE, None) } } + #[inline] + pub fn dtype(&self) -> Option { + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { self._tab.get::(ValueNode::VT_DTYPE, None) } + } } impl flatbuffers::Verifiable for ValueNode<'_> { @@ -10028,6 +10039,7 @@ impl flatbuffers::Verifiable for ValueNode<'_> { .visit_field::>, >>("shape", Self::VT_SHAPE, false)? + .visit_field::("dtype", Self::VT_DTYPE, false)? .finish(); Ok(()) } @@ -10036,11 +10048,15 @@ pub struct ValueNodeArgs<'a> { pub shape: Option< flatbuffers::WIPOffset>>>, >, + pub dtype: Option, } impl<'a> Default for ValueNodeArgs<'a> { #[inline] fn default() -> Self { - ValueNodeArgs { shape: None } + ValueNodeArgs { + shape: None, + dtype: None, + } } } @@ -10060,6 +10076,11 @@ impl<'a: 'b, 'b, A: flatbuffers::Allocator + 'a> ValueNodeBuilder<'a, 'b, A> { .push_slot_always::>(ValueNode::VT_SHAPE, shape); } #[inline] + pub fn add_dtype(&mut self, dtype: DataType) { + self.fbb_ + .push_slot_always::(ValueNode::VT_DTYPE, dtype); + } + #[inline] pub fn new(_fbb: &'b mut flatbuffers::FlatBufferBuilder<'a, A>) -> ValueNodeBuilder<'a, 'b, A> { let start = _fbb.start_table(); ValueNodeBuilder { @@ -10078,6 +10099,7 @@ impl core::fmt::Debug for ValueNode<'_> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { let mut ds = f.debug_struct("ValueNode"); ds.field("shape", &self.shape()); + ds.field("dtype", &self.dtype()); ds.finish() } }
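Example (not part of the patch): a minimal consumer-side sketch of how the dtype metadata added above surfaces through the public API once a model has been converted with the updated rten-convert. It relies only on items that appear in this diff (Model::load, Model::input_ids, Model::node_info, NodeInfo::dtype, NodeInfo::name and the Display impl for DataType); the "model.rten" path is a placeholder.

    use rten::Model;

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        // Load a serialized model. Models converted before this change simply
        // report `None` for every value node's dtype.
        let data = std::fs::read("model.rten")?;
        let model = Model::load(data)?;

        // List each graph input with its expected element type, falling back
        // to a placeholder when no dtype was recorded.
        for &id in model.input_ids() {
            if let Some(info) = model.node_info(id) {
                let dtype = info
                    .dtype()
                    .map(|dt| dt.to_string())
                    .unwrap_or("(unknown dtype)".to_string());
                println!("{}: {}", info.name().unwrap_or("(unnamed)"), dtype);
            }
        }
        Ok(())
    }

For an f32 input named "input" this would print "input: f32", matching the formatting rten-cli now uses in print_input_output_list.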