From ba8ac024459ade03fe8c07d6809b467b632db1b7 Mon Sep 17 00:00:00 2001
From: Robert Knight
Date: Sat, 30 Mar 2024 07:16:08 +0000
Subject: [PATCH 1/3] Output more helpful error if operator is not available

When loading a model with an operator that is supported by RTen but not
enabled for the current model or rten crate features, include the name of
the operator in the error. This makes rectifying the problem much easier.
---
 src/lib.rs   |  2 +-
 src/model.rs | 30 +++++++++++++++++++++++++-----
 2 files changed, 26 insertions(+), 6 deletions(-)

diff --git a/src/lib.rs b/src/lib.rs
index e8df6c7c..789ac669 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -48,7 +48,7 @@ pub mod ctc;
 pub mod ops;
 
 pub use graph::{Dimension, NodeId, RunOptions};
-pub use model::{DefaultOperatorFactory, Model, ModelLoadError, NodeInfo, OpRegistry};
+pub use model::{DefaultOperatorFactory, Model, ModelLoadError, NodeInfo, OpRegistry, ReadOpError};
 pub use model_metadata::ModelMetadata;
 pub use ops::{FloatOperators, Input, Operators, Output};
 pub use timer::Timer;
diff --git a/src/model.rs b/src/model.rs
index e780defa..ad3236fc 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -417,7 +417,11 @@ impl OpRegistry {
     fn read_op(&self, op: &OperatorNode) -> ReadOpResult {
         self.ops
             .get(&op.type_())
-            .ok_or(ReadOpError::UnsupportedOperator)
+            .ok_or_else(|| {
+                ReadOpError::UnsupportedOperator(
+                    op.type_().variant_name().unwrap_or("(unknown)").to_string(),
+                )
+            })
             .and_then(|read_fn| read_fn(op))
     }
 
@@ -542,19 +546,21 @@ impl OpRegistry {
 }
 
 /// Error type for errors that occur when de-serializing an operator.
-#[derive(Debug)]
+#[derive(Debug, PartialEq)]
 pub enum ReadOpError {
     /// The operator attributes were missing or of the wrong type.
     AttrError,
     /// The operator type is incorrect or unsupported.
-    UnsupportedOperator,
+    UnsupportedOperator(String),
 }
 
 impl Display for ReadOpError {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
         match self {
             ReadOpError::AttrError => write!(f, "invalid attributes for operator"),
-            ReadOpError::UnsupportedOperator => write!(f, "unsupported operator"),
+            ReadOpError::UnsupportedOperator(name) => {
+                write!(f, "operator {name} is not supported or not enabled")
+            }
         }
     }
 }
@@ -944,7 +950,7 @@ fn read_trilu_op(node: &OperatorNode) -> ReadOpResult {
 }
 
 /// Errors reported by [Model::load].
-#[derive(Debug)]
+#[derive(Debug, PartialEq)]
 pub enum ModelLoadError {
     SchemaVersionUnsupported,
 
@@ -1145,6 +1151,7 @@ mod tests {
     use crate::model_builder::{MetadataArgs, ModelBuilder, OpType};
     use crate::ops;
     use crate::ops::{BoxOrder, CoordTransformMode, NearestMode, OpError, ResizeMode, Scalar};
+    use crate::{ModelLoadError, OpRegistry, ReadOpError};
 
     fn generate_model_buffer() -> Vec<u8> {
         let mut builder = ModelBuilder::new();
@@ -1205,6 +1212,19 @@ mod tests {
         );
     }
 
+    #[test]
+    fn test_unsupported_operator() {
+        let buffer = generate_model_buffer();
+        let registry = OpRegistry::new();
+        let result = Model::load_with_ops(&buffer, &registry);
+        assert_eq!(
+            result.err(),
+            Some(ModelLoadError::OperatorInvalid(
+                ReadOpError::UnsupportedOperator("Concat".to_string())
+            ))
+        );
+    }
+
     #[test]
     fn test_shape_info() {
         let buffer = generate_model_buffer();

From 6212df4d08ce2b408c7108fb687db24f3ae648d6 Mon Sep 17 00:00:00 2001
From: Robert Knight
Date: Fri, 29 Mar 2024 09:08:31 +0000
Subject: [PATCH 2/3] Implement RandomUniform operator

This adds a new dependency on fastrand. Most models won't need random
number generation operators, so make this an optional feature.
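
Note: combined with the previous patch, loading a model that uses
`RandomUniform` in a build without the `random` feature now fails with an
error naming the operator, instead of a generic "unsupported operator"
message. A minimal sketch of how a caller might surface this, assuming
`Model::load` accepts the serialized model bytes as described in the crate
docs (the `load_model` helper and its messages are illustrative only, not
part of this patch):

    use rten::{Model, ModelLoadError, ReadOpError};

    // Hypothetical helper: load a model and convert load errors to strings.
    fn load_model(data: &[u8]) -> Result<Model, String> {
        Model::load(data).map_err(|err| match err {
            // Returned when the operator exists in the schema but no read
            // function was registered for it, e.g. because a crate feature
            // such as `random` is disabled.
            ModelLoadError::OperatorInvalid(ReadOpError::UnsupportedOperator(name)) => {
                format!("operator {name} is not supported or not enabled")
            }
            other => format!("failed to load model: {other:?}"),
        })
    }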
--- Cargo.lock | 18 +- Cargo.toml | 5 + rten-cli/Cargo.toml | 4 +- rten-convert/rten_convert/converter.py | 9 + rten-convert/rten_convert/schema_generated.py | 164 ++++++++++++- rten-examples/Cargo.toml | 2 +- rten-vecmath/Cargo.toml | 2 +- src/model.rs | 33 +++ src/model_builder.rs | 21 ++ src/ops/mod.rs | 13 + src/ops/random.rs | 147 ++++++++++++ src/schema.fbs | 9 + src/schema_generated.rs | 224 +++++++++++++++++- 13 files changed, 626 insertions(+), 25 deletions(-) create mode 100644 src/ops/random.rs diff --git a/Cargo.lock b/Cargo.lock index 775c9106..18fd08eb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -100,12 +100,9 @@ checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" [[package]] name = "fastrand" -version = "1.9.0" +version = "2.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e51093e27b0797c359783294ca4f0a911c270184cb10f85783b118614a1501be" -dependencies = [ - "instant", -] +checksum = "658bd65b1cf4c852a3cc96f18a8ce7b5640f6b703f905c7d74532294c2a63984" [[package]] name = "fdeflate" @@ -157,15 +154,6 @@ dependencies = [ "png", ] -[[package]] -name = "instant" -version = "0.1.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" -dependencies = [ - "cfg-if", -] - [[package]] name = "itoa" version = "1.0.9" @@ -309,9 +297,11 @@ dependencies = [ name = "rten" version = "0.5.0" dependencies = [ + "fastrand", "flatbuffers", "libm", "rayon", + "rten", "rten-bench", "rten-tensor", "rten-vecmath", diff --git a/Cargo.toml b/Cargo.toml index bf578c40..8cbb2e57 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ description = "Machine learning runtime" license = "MIT OR Apache-2.0" homepage = "https://github.com/robertknight/rten" repository = "https://github.com/robertknight/rten" +resolver = "2" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -37,8 +38,10 @@ rayon = "1.7.0" smallvec = { version = "1.10.0", features = ["union", "const_generics", "const_new"] } rten-tensor = { path = "./rten-tensor", version = "0.4.0" } rten-vecmath = { path = "./rten-vecmath", version = "0.4.0" } +fastrand = { version = "2.0.2", optional = true } [dev-dependencies] +rten = { path = ".", features = ["random"] } rten-bench = { path = "./rten-bench" } serde_json = "1.0.91" @@ -50,6 +53,8 @@ crate-type = ["lib", "cdylib"] avx512 = ["rten-vecmath/avx512"] # Generate WebAssembly API using wasm-bindgen. wasm_api = [] +# Enable operators that generate random numbers. 
+random = ["fastrand"] [target.'cfg(target_arch = "wasm32")'.dependencies] wasm-bindgen = "0.2.83" diff --git a/rten-cli/Cargo.toml b/rten-cli/Cargo.toml index 45dc0589..4e443769 100644 --- a/rten-cli/Cargo.toml +++ b/rten-cli/Cargo.toml @@ -9,8 +9,8 @@ homepage = "https://github.com/robertknight/rten" repository = "https://github.com/robertknight/rten" [dependencies] -fastrand = "1.9.0" -rten = { path = "../", version = "0.5.0" } +fastrand = "2.0.2" +rten = { path = "../", version = "0.5.0", features=["random"] } rten-tensor = { path = "../rten-tensor", version = "0.4.0" } lexopt = "0.3.0" diff --git a/rten-convert/rten_convert/converter.py b/rten-convert/rten_convert/converter.py index e7612845..e8d0a199 100644 --- a/rten-convert/rten_convert/converter.py +++ b/rten-convert/rten_convert/converter.py @@ -798,6 +798,15 @@ def op_node_from_onnx_operator( attrs = sg.OneHotAttrsT() attrs.axis = op_reader.get_attr("axis", "int", -1) + case "RandomUniform": + attrs = sg.RandomUniformAttrsT() + op_reader.check_attr("dtype", "int", 1) + + attrs.seed = op_reader.get_attr("seed", "float", None) + attrs.shape = op_reader.require_attr("shape", "ints") + attrs.low = op_reader.get_attr("low", "float", 0.0) + attrs.high = op_reader.get_attr("high", "float", 1.0) + case ( "ReduceL2" | "ReduceMax" diff --git a/rten-convert/rten_convert/schema_generated.py b/rten-convert/rten_convert/schema_generated.py index 6b2bee58..f54338df 100644 --- a/rten-convert/rten_convert/schema_generated.py +++ b/rten-convert/rten_convert/schema_generated.py @@ -102,6 +102,7 @@ class OperatorType(object): GatherElements = 92 LayerNormalization = 93 ReduceSumSquare = 94 + RandomUniform = 95 class RNNDirection(object): @@ -170,6 +171,7 @@ class OperatorAttrs(object): ScatterNDAttrs = 28 NonMaxSuppressionAttrs = 29 LayerNormalizationAttrs = 30 + RandomUniformAttrs = 31 def OperatorAttrsCreator(unionType, table): from flatbuffers.table import Table @@ -235,6 +237,8 @@ def OperatorAttrsCreator(unionType, table): return NonMaxSuppressionAttrsT.InitFromBuf(table.Bytes, table.Pos) if unionType == OperatorAttrs().LayerNormalizationAttrs: return LayerNormalizationAttrsT.InitFromBuf(table.Bytes, table.Pos) + if unionType == OperatorAttrs().RandomUniformAttrs: + return RandomUniformAttrsT.InitFromBuf(table.Bytes, table.Pos) return None @@ -2675,6 +2679,164 @@ def Pack(self, builder): return oneHotAttrs +class RandomUniformAttrs(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = RandomUniformAttrs() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsRandomUniformAttrs(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + @classmethod + def RandomUniformAttrsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x52\x54\x45\x4E", size_prefixed=size_prefixed) + + # RandomUniformAttrs + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # RandomUniformAttrs + def Shape(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + a = self._tab.Vector(o) + return self._tab.Get(flatbuffers.number_types.Uint32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4)) + return 0 + + # RandomUniformAttrs + def ShapeAsNumpy(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint32Flags, o) + return 0 + + # RandomUniformAttrs + def ShapeLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # RandomUniformAttrs + def ShapeIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + return o == 0 + + # RandomUniformAttrs + def High(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos) + return 0.0 + + # RandomUniformAttrs + def Low(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos) + return 0.0 + + # RandomUniformAttrs + def Seed(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos) + return None + +def RandomUniformAttrsStart(builder): + builder.StartObject(4) + +def RandomUniformAttrsAddShape(builder, shape): + builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(shape), 0) + +def RandomUniformAttrsStartShapeVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + +def RandomUniformAttrsAddHigh(builder, high): + builder.PrependFloat32Slot(1, high, 0.0) + +def RandomUniformAttrsAddLow(builder, low): + builder.PrependFloat32Slot(2, low, 0.0) + +def RandomUniformAttrsAddSeed(builder, seed): + builder.PrependFloat32Slot(3, seed, None) + +def RandomUniformAttrsEnd(builder): + return builder.EndObject() + + +try: + from typing import List +except: + pass + +class RandomUniformAttrsT(object): + + # RandomUniformAttrsT + def __init__(self): + self.shape = None # type: List[int] + self.high = 0.0 # type: float + self.low = 0.0 # type: float + self.seed = None # type: Optional[float] + + @classmethod + def InitFromBuf(cls, buf, pos): + randomUniformAttrs = RandomUniformAttrs() + randomUniformAttrs.Init(buf, pos) + return cls.InitFromObj(randomUniformAttrs) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos+n) + + @classmethod + def InitFromObj(cls, randomUniformAttrs): + x = RandomUniformAttrsT() + x._UnPack(randomUniformAttrs) + return x + + # RandomUniformAttrsT + def _UnPack(self, randomUniformAttrs): + if randomUniformAttrs is None: + return + if not randomUniformAttrs.ShapeIsNone(): + if np is None: + self.shape = [] + for i in range(randomUniformAttrs.ShapeLength()): + 
self.shape.append(randomUniformAttrs.Shape(i)) + else: + self.shape = randomUniformAttrs.ShapeAsNumpy() + self.high = randomUniformAttrs.High() + self.low = randomUniformAttrs.Low() + self.seed = randomUniformAttrs.Seed() + + # RandomUniformAttrsT + def Pack(self, builder): + if self.shape is not None: + if np is not None and type(self.shape) is np.ndarray: + shape = builder.CreateNumpyVector(self.shape) + else: + RandomUniformAttrsStartShapeVector(builder, len(self.shape)) + for i in reversed(range(len(self.shape))): + builder.PrependUint32(self.shape[i]) + shape = builder.EndVector() + RandomUniformAttrsStart(builder) + if self.shape is not None: + RandomUniformAttrsAddShape(builder, shape) + RandomUniformAttrsAddHigh(builder, self.high) + RandomUniformAttrsAddLow(builder, self.low) + RandomUniformAttrsAddSeed(builder, self.seed) + randomUniformAttrs = RandomUniformAttrsEnd(builder) + return randomUniformAttrs + + class ReduceMeanAttrs(object): __slots__ = ['_tab'] @@ -3746,7 +3908,7 @@ class OperatorNodeT(object): def __init__(self): self.type = 0 # type: int self.attrsType = 0 # type: int - self.attrs = None # type: Union[None, ArgMaxAttrsT, AveragePoolAttrsT, BatchNormalizationAttrsT, CastAttrsT, ConcatAttrsT, ConstantOfShapeAttrsT, ConvAttrsT, ConvTransposeAttrsT, FlattenAttrsT, GatherAttrsT, GemmAttrsT, GRUAttrsT, LeakyReluAttrsT, LSTMAttrsT, MaxPoolAttrsT, ReduceMeanAttrsT, ReshapeAttrsT, ResizeAttrsT, SplitAttrsT, SoftmaxAttrsT, TransposeAttrsT, ModAttrsT, ScatterElementsAttrsT, OneHotAttrsT, TopKAttrsT, HardSigmoidAttrsT, TriluAttrsT, ScatterNDAttrsT, NonMaxSuppressionAttrsT, LayerNormalizationAttrsT] + self.attrs = None # type: Union[None, ArgMaxAttrsT, AveragePoolAttrsT, BatchNormalizationAttrsT, CastAttrsT, ConcatAttrsT, ConstantOfShapeAttrsT, ConvAttrsT, ConvTransposeAttrsT, FlattenAttrsT, GatherAttrsT, GemmAttrsT, GRUAttrsT, LeakyReluAttrsT, LSTMAttrsT, MaxPoolAttrsT, ReduceMeanAttrsT, ReshapeAttrsT, ResizeAttrsT, SplitAttrsT, SoftmaxAttrsT, TransposeAttrsT, ModAttrsT, ScatterElementsAttrsT, OneHotAttrsT, TopKAttrsT, HardSigmoidAttrsT, TriluAttrsT, ScatterNDAttrsT, NonMaxSuppressionAttrsT, LayerNormalizationAttrsT, RandomUniformAttrsT] self.inputs = None # type: List[int] self.outputs = None # type: List[int] diff --git a/rten-examples/Cargo.toml b/rten-examples/Cargo.toml index c36b055e..4090929e 100644 --- a/rten-examples/Cargo.toml +++ b/rten-examples/Cargo.toml @@ -9,7 +9,7 @@ homepage = "https://github.com/robertknight/rten" repository = "https://github.com/robertknight/rten" [dependencies] -fastrand = "1.9.0" +fastrand = "2.0.2" hound = "3.5.1" image = { version = "0.24.6", default-features = false, features = ["png", "jpeg", "jpeg_rayon", "webp"] } lexopt = "0.3.0" diff --git a/rten-vecmath/Cargo.toml b/rten-vecmath/Cargo.toml index e1cd381a..5b221244 100644 --- a/rten-vecmath/Cargo.toml +++ b/rten-vecmath/Cargo.toml @@ -9,7 +9,7 @@ homepage = "https://github.com/robertknight/rten" repository = "https://github.com/robertknight/rten" [dev-dependencies] -fastrand = "1.9.0" +fastrand = "2.0.2" libm = "0.2.6" [lib] diff --git a/src/model.rs b/src/model.rs index ad3236fc..e84500cd 100644 --- a/src/model.rs +++ b/src/model.rs @@ -349,6 +349,10 @@ impl_default_factory!(OneHot, read_onehot_op); impl_default_factory!(Or); impl_default_factory!(Pad); impl_default_factory!(Pow); + +#[cfg(feature = "random")] +impl_default_factory!(RandomUniform, read_random_uniform_op); + impl_default_factory!(Range); impl_default_factory!(Reciprocal); impl_default_factory!(ReduceL2, 
read_reduce_l2_op);
@@ -504,6 +508,10 @@ impl OpRegistry {
         register_op!(Or);
         register_op!(Pad);
         register_op!(Pow);
+
+        #[cfg(feature = "random")]
+        register_op!(RandomUniform);
+
         register_op!(Range);
         register_op!(Reciprocal);
         register_op!(ReduceL2);
@@ -818,6 +826,24 @@ fn read_non_max_suppression_op(node: &OperatorNode) -> ReadOpResult {
 
 read_axis_op!(read_onehot_op, attrs_as_one_hot_attrs, OneHot);
 
+#[cfg(feature = "random")]
+fn read_random_uniform_op(node: &OperatorNode) -> ReadOpResult {
+    let attrs = node
+        .attrs_as_random_uniform_attrs()
+        .ok_or(ReadOpError::AttrError)?;
+    let shape = attrs
+        .shape()
+        .map(|shape| shape.iter().map(|size| size as usize).collect())
+        .unwrap_or(vec![]);
+
+    Ok(Box::new(ops::RandomUniform {
+        shape,
+        high: attrs.high(),
+        low: attrs.low(),
+        seed: attrs.seed(),
+    }))
+}
+
 fn read_reduce_attrs(node: &OperatorNode) -> Result<(Option<Vec<i32>>, bool), ReadOpError> {
     let attrs = node
         .attrs_as_reduce_mean_attrs()
@@ -1530,6 +1556,13 @@ mod tests {
         add_operator!(Pad, [input_node, pads]);
         add_operator!(Pow, [input_node, input_node]);
 
+        add_operator!(RandomUniform, [], {
+            shape: vec![50, 50],
+            low: 0.,
+            high: 1.,
+            seed: None,
+        });
+
         let range_start_node = builder.add_value("range_start", None);
         let range_limit_node = builder.add_value("range_limit", None);
         let range_delta_node = builder.add_value("range_delta", None);
diff --git a/src/model_builder.rs b/src/model_builder.rs
index d97c65e0..50279045 100644
--- a/src/model_builder.rs
+++ b/src/model_builder.rs
@@ -15,6 +15,9 @@ use crate::ops::{
 };
 use crate::schema_generated as sg;
 
+#[cfg(feature = "random")]
+use crate::ops::RandomUniform;
+
 /// Enum of all the built-in operators
 pub enum OpType {
     Abs,
@@ -73,6 +76,10 @@ pub enum OpType {
     Or,
     Pad,
     Pow,
+
+    #[cfg(feature = "random")]
+    RandomUniform(RandomUniform),
+
     Range,
     Reciprocal,
     ReduceMax(ReduceMax),
@@ -552,6 +559,20 @@ impl<'a> ModelBuilder<'a> {
             }
             OpType::Pad => op!(Pad),
             OpType::Pow => op!(Pow),
+
+            #[cfg(feature = "random")]
+            OpType::RandomUniform(args) => {
+                let shape = self.create_vec(Some(args.shape), |size| size as u32);
+                op_with_attrs!(RandomUniform, RandomUniformAttrs, {
+                    sg::RandomUniformAttrsArgs {
+                        high: args.high,
+                        low: args.low,
+                        seed: args.seed,
+                        shape,
+                    }
+                })
+            }
+
             OpType::Range => op!(Range),
             OpType::Reciprocal => op!(Reciprocal),
             OpType::ReduceMax(args) => {
diff --git a/src/ops/mod.rs b/src/ops/mod.rs
index 427d03ea..49a72a3b 100644
--- a/src/ops/mod.rs
+++ b/src/ops/mod.rs
@@ -34,6 +34,10 @@ mod non_max_suppression;
 mod norm;
 mod pad;
 mod pooling;
+
+#[cfg(feature = "random")]
+mod random;
+
 mod reduce;
 mod resize;
 mod rnn;
@@ -72,6 +76,10 @@ pub use pad::{pad, Pad};
 pub use pooling::{
     average_pool, global_average_pool, max_pool, AveragePool, GlobalAveragePool, MaxPool,
 };
+
+#[cfg(feature = "random")]
+pub use random::RandomUniform;
+
 pub use reduce::{
     arg_max, arg_min, cum_sum, nonzero, reduce_l2, reduce_max, reduce_mean, reduce_min,
     reduce_prod, reduce_sum, reduce_sum_square, topk, ArgMax, ArgMin, CumSum, NonZero, ReduceL2,
@@ -671,6 +679,11 @@ pub struct InputList<'a> {
 }
 
 impl<'a> InputList<'a> {
+    /// Construct an empty input list.
+    pub fn new() -> InputList<'static> {
+        InputList { inputs: vec![] }
+    }
+
     pub fn from<'b>(inputs: &[Input<'b>]) -> InputList<'b> {
         InputList {
             inputs: inputs.iter().cloned().map(Some).collect(),
diff --git a/src/ops/random.rs b/src/ops/random.rs
new file mode 100644
index 00000000..ece44e2a
--- /dev/null
+++ b/src/ops/random.rs
@@ -0,0 +1,147 @@
+use fastrand::Rng;
+use rten_tensor::Tensor;
+
+use crate::ops::{InputList, IntoOpResult, OpError, Operator, Output};
+
+#[derive(Debug)]
+pub struct RandomUniform {
+    pub low: f32,
+    pub high: f32,
+    pub shape: Vec<usize>,
+
+    /// Random seed.
+    ///
+    /// This unusually uses an `f32` value for consistency with the ONNX
+    /// specification.
+    pub seed: Option<f32>,
+}
+
+impl Operator for RandomUniform {
+    fn name(&self) -> &str {
+        "RandomUniform"
+    }
+
+    fn run(&self, _inputs: InputList) -> Result<Vec<Output>, OpError> {
+        let scale_value = |val: f32| self.low + val * (self.high - self.low);
+        let shape = self.shape.as_slice();
+
+        let mut rng = if let Some(seed) = self.seed {
+            Rng::with_seed(seed.to_bits() as u64)
+        } else {
+            Rng::new()
+        };
+
+        Tensor::from_simple_fn(shape, || scale_value(rng.f32())).into_op_result()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use rten_tensor::prelude::*;
+    use rten_tensor::Tensor;
+
+    use crate::ops::{InputList, Operator};
+
+    use super::RandomUniform;
+
+    #[test]
+    fn test_random_uniform() {
+        struct Case {
+            low: f32,
+            high: f32,
+            shape: Vec<usize>,
+            seed: Option<f32>,
+        }
+
+        let cases = [
+            // Standard value range.
+            Case {
+                low: 0.,
+                high: 1.,
+                shape: vec![50, 50],
+                seed: None,
+            },
+            // Non-standard low/high ranges.
+            Case {
+                low: -5.,
+                high: -1.,
+                shape: vec![50, 50],
+                seed: None,
+            },
+            Case {
+                low: 1.,
+                high: 5.,
+                shape: vec![50, 50],
+                seed: None,
+            },
+            // Custom seed
+            Case {
+                low: 0.,
+                high: 1.,
+                shape: vec![50, 50],
+                seed: Some(0.5),
+            },
+        ];
+
+        for Case {
+            low,
+            high,
+            shape,
+            seed,
+        } in cases
+        {
+            let op = RandomUniform {
+                low,
+                high,
+                shape,
+                seed,
+            };
+            let output = op.run(InputList::new()).unwrap().remove(0);
+            let output: Tensor = output.try_into().unwrap();
+
+            assert_eq!(output.shape(), op.shape);
+
+            // Create buckets to count elements in N sub-intervals of
+            // `[op.low, op.high]`.
+            let mut buckets = vec![0; 10];
+            let bucket_size = (op.high - op.low) as f32 / buckets.len() as f32;
+
+            // Test generated outputs are within expected range.
+            for el in output.iter().copied() {
+                let low = op.low;
+                let high = op.high;
+                assert!(
+                    el >= low && el <= high,
+                    "value {el} outside range {low}..{high}"
+                );
+
+                let bucket_idx = ((el - low) / bucket_size) as usize;
+                buckets[bucket_idx] += 1;
+            }
+
+            // Check that distribution is approximately uniform. A more
+            // principled approach would be to do a chi-squared test.
+            let expected_count_per_bucket = (output.len() / buckets.len()) as i32;
+            let max_expected_count_diff = buckets
+                .iter()
+                .map(|count| (count - expected_count_per_bucket).abs())
+                .max()
+                .unwrap();
+            let tolerance = (expected_count_per_bucket as f32) * 0.3;
+            assert!(
+                (max_expected_count_diff as f32) <= tolerance,
+                "max deviation from expected bucket size {max_expected_count_diff} > tolerance {tolerance}"
+            );
+
+            // Test that repeated generation produces the same output if the
+            // seed is fixed, or different output otherwise.
+ let output_2 = op.run(InputList::new()).unwrap().remove(0); + let output_2: Tensor = output_2.try_into().unwrap(); + if let Some(_seed) = seed { + assert_eq!(output, output_2); + } else { + assert_ne!(output, output_2); + } + } + } +} diff --git a/src/schema.fbs b/src/schema.fbs index 4b70e0e3..8f7769b0 100644 --- a/src/schema.fbs +++ b/src/schema.fbs @@ -105,6 +105,7 @@ enum OperatorType: ubyte { GatherElements, LayerNormalization, ReduceSumSquare, + RandomUniform, } enum RNNDirection: ubyte { @@ -178,6 +179,7 @@ union OperatorAttrs { ScatterNDAttrs, NonMaxSuppressionAttrs, LayerNormalizationAttrs, + RandomUniformAttrs, } table ArgMaxAttrs { @@ -308,6 +310,13 @@ table OneHotAttrs { axis:int; } +table RandomUniformAttrs { + shape:[uint]; + high:float; + low:float; + seed:float = null; +} + table ReduceMeanAttrs { axes:[int]; keep_dims:bool; diff --git a/src/schema_generated.rs b/src/schema_generated.rs index 80cb11ed..f4d2a5d5 100644 --- a/src/schema_generated.rs +++ b/src/schema_generated.rs @@ -18,13 +18,13 @@ pub const ENUM_MIN_OPERATOR_TYPE: u8 = 0; since = "2.0.0", note = "Use associated constants instead. This will no longer be generated in 2021." )] -pub const ENUM_MAX_OPERATOR_TYPE: u8 = 94; +pub const ENUM_MAX_OPERATOR_TYPE: u8 = 95; #[deprecated( since = "2.0.0", note = "Use associated constants instead. This will no longer be generated in 2021." )] #[allow(non_camel_case_types)] -pub const ENUM_VALUES_OPERATOR_TYPE: [OperatorType; 95] = [ +pub const ENUM_VALUES_OPERATOR_TYPE: [OperatorType; 96] = [ OperatorType::Add, OperatorType::ArgMin, OperatorType::ArgMax, @@ -120,6 +120,7 @@ pub const ENUM_VALUES_OPERATOR_TYPE: [OperatorType; 95] = [ OperatorType::GatherElements, OperatorType::LayerNormalization, OperatorType::ReduceSumSquare, + OperatorType::RandomUniform, ]; #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] @@ -222,9 +223,10 @@ impl OperatorType { pub const GatherElements: Self = Self(92); pub const LayerNormalization: Self = Self(93); pub const ReduceSumSquare: Self = Self(94); + pub const RandomUniform: Self = Self(95); pub const ENUM_MIN: u8 = 0; - pub const ENUM_MAX: u8 = 94; + pub const ENUM_MAX: u8 = 95; pub const ENUM_VALUES: &'static [Self] = &[ Self::Add, Self::ArgMin, @@ -321,6 +323,7 @@ impl OperatorType { Self::GatherElements, Self::LayerNormalization, Self::ReduceSumSquare, + Self::RandomUniform, ]; /// Returns the variant's name or "" if unknown. pub fn variant_name(self) -> Option<&'static str> { @@ -420,6 +423,7 @@ impl OperatorType { Self::GatherElements => Some("GatherElements"), Self::LayerNormalization => Some("LayerNormalization"), Self::ReduceSumSquare => Some("ReduceSumSquare"), + Self::RandomUniform => Some("RandomUniform"), _ => None, } } @@ -1046,13 +1050,13 @@ pub const ENUM_MIN_OPERATOR_ATTRS: u8 = 0; since = "2.0.0", note = "Use associated constants instead. This will no longer be generated in 2021." )] -pub const ENUM_MAX_OPERATOR_ATTRS: u8 = 30; +pub const ENUM_MAX_OPERATOR_ATTRS: u8 = 31; #[deprecated( since = "2.0.0", note = "Use associated constants instead. This will no longer be generated in 2021." 
 )]
 #[allow(non_camel_case_types)]
-pub const ENUM_VALUES_OPERATOR_ATTRS: [OperatorAttrs; 31] = [
+pub const ENUM_VALUES_OPERATOR_ATTRS: [OperatorAttrs; 32] = [
     OperatorAttrs::NONE,
     OperatorAttrs::ArgMaxAttrs,
     OperatorAttrs::AveragePoolAttrs,
@@ -1084,6 +1088,7 @@ pub const ENUM_VALUES_OPERATOR_ATTRS: [OperatorAttrs; 31] = [
     OperatorAttrs::ScatterNDAttrs,
     OperatorAttrs::NonMaxSuppressionAttrs,
     OperatorAttrs::LayerNormalizationAttrs,
+    OperatorAttrs::RandomUniformAttrs,
 ];
 
 #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
@@ -1122,9 +1127,10 @@ impl OperatorAttrs {
     pub const ScatterNDAttrs: Self = Self(28);
     pub const NonMaxSuppressionAttrs: Self = Self(29);
     pub const LayerNormalizationAttrs: Self = Self(30);
+    pub const RandomUniformAttrs: Self = Self(31);
 
     pub const ENUM_MIN: u8 = 0;
-    pub const ENUM_MAX: u8 = 30;
+    pub const ENUM_MAX: u8 = 31;
     pub const ENUM_VALUES: &'static [Self] = &[
         Self::NONE,
         Self::ArgMaxAttrs,
@@ -1157,6 +1163,7 @@ impl OperatorAttrs {
         Self::ScatterNDAttrs,
         Self::NonMaxSuppressionAttrs,
         Self::LayerNormalizationAttrs,
+        Self::RandomUniformAttrs,
     ];
     /// Returns the variant's name or "" if unknown.
     pub fn variant_name(self) -> Option<&'static str> {
@@ -1192,6 +1199,7 @@ impl OperatorAttrs {
             Self::ScatterNDAttrs => Some("ScatterNDAttrs"),
             Self::NonMaxSuppressionAttrs => Some("NonMaxSuppressionAttrs"),
             Self::LayerNormalizationAttrs => Some("LayerNormalizationAttrs"),
+            Self::RandomUniformAttrs => Some("RandomUniformAttrs"),
             _ => None,
         }
     }
@@ -4639,6 +4647,184 @@ impl core::fmt::Debug for OneHotAttrs<'_> {
         ds.finish()
     }
 }
+pub enum RandomUniformAttrsOffset {}
+#[derive(Copy, Clone, PartialEq)]
+
+pub struct RandomUniformAttrs<'a> {
+    pub _tab: flatbuffers::Table<'a>,
+}
+
+impl<'a> flatbuffers::Follow<'a> for RandomUniformAttrs<'a> {
+    type Inner = RandomUniformAttrs<'a>;
+    #[inline]
+    unsafe fn follow(buf: &'a [u8], loc: usize) -> Self::Inner {
+        Self {
+            _tab: flatbuffers::Table::new(buf, loc),
+        }
+    }
+}
+
+impl<'a> RandomUniformAttrs<'a> {
+    pub const VT_SHAPE: flatbuffers::VOffsetT = 4;
+    pub const VT_HIGH: flatbuffers::VOffsetT = 6;
+    pub const VT_LOW: flatbuffers::VOffsetT = 8;
+    pub const VT_SEED: flatbuffers::VOffsetT = 10;
+
+    #[inline]
+    pub unsafe fn init_from_table(table: flatbuffers::Table<'a>) -> Self {
+        RandomUniformAttrs { _tab: table }
+    }
+    #[allow(unused_mut)]
+    pub fn create<'bldr: 'args, 'args: 'mut_bldr, 'mut_bldr>(
+        _fbb: &'mut_bldr mut flatbuffers::FlatBufferBuilder<'bldr>,
+        args: &'args RandomUniformAttrsArgs<'args>,
+    ) -> flatbuffers::WIPOffset<RandomUniformAttrs<'bldr>> {
+        let mut builder = RandomUniformAttrsBuilder::new(_fbb);
+        if let Some(x) = args.seed {
+            builder.add_seed(x);
+        }
+        builder.add_low(args.low);
+        builder.add_high(args.high);
+        if let Some(x) = args.shape {
+            builder.add_shape(x);
+        }
+        builder.finish()
+    }
+
+    #[inline]
+    pub fn shape(&self) -> Option<flatbuffers::Vector<'a, u32>> {
+        // Safety:
+        // Created from valid Table for this object
+        // which contains a valid value in this slot
+        unsafe {
+            self._tab
+                .get::<flatbuffers::ForwardsUOffset<flatbuffers::Vector<'a, u32>>>(
+                    RandomUniformAttrs::VT_SHAPE,
+                    None,
+                )
+        }
+    }
+    #[inline]
+    pub fn high(&self) -> f32 {
+        // Safety:
+        // Created from valid Table for this object
+        // which contains a valid value in this slot
+        unsafe {
+            self._tab
+                .get::<f32>(RandomUniformAttrs::VT_HIGH, Some(0.0))
+                .unwrap()
+        }
+    }
+    #[inline]
+    pub fn low(&self) -> f32 {
+        // Safety:
+        // Created from valid Table for this object
+        // which contains a valid value in this slot
+        unsafe {
+            self._tab
+                .get::<f32>(RandomUniformAttrs::VT_LOW, Some(0.0))
+                .unwrap()
+        }
+    }
+    #[inline]
+    
pub fn seed(&self) -> Option<f32> {
+        // Safety:
+        // Created from valid Table for this object
+        // which contains a valid value in this slot
+        unsafe { self._tab.get::<f32>(RandomUniformAttrs::VT_SEED, None) }
+    }
+}
+
+impl flatbuffers::Verifiable for RandomUniformAttrs<'_> {
+    #[inline]
+    fn run_verifier(
+        v: &mut flatbuffers::Verifier,
+        pos: usize,
+    ) -> Result<(), flatbuffers::InvalidFlatbuffer> {
+        use self::flatbuffers::Verifiable;
+        v.visit_table(pos)?
+            .visit_field::<flatbuffers::ForwardsUOffset<flatbuffers::Vector<'_, u32>>>(
+                "shape",
+                Self::VT_SHAPE,
+                false,
+            )?
+            .visit_field::<f32>("high", Self::VT_HIGH, false)?
+            .visit_field::<f32>("low", Self::VT_LOW, false)?
+            .visit_field::<f32>("seed", Self::VT_SEED, false)?
+            .finish();
+        Ok(())
+    }
+}
+pub struct RandomUniformAttrsArgs<'a> {
+    pub shape: Option<flatbuffers::WIPOffset<flatbuffers::Vector<'a, u32>>>,
+    pub high: f32,
+    pub low: f32,
+    pub seed: Option<f32>,
+}
+impl<'a> Default for RandomUniformAttrsArgs<'a> {
+    #[inline]
+    fn default() -> Self {
+        RandomUniformAttrsArgs {
+            shape: None,
+            high: 0.0,
+            low: 0.0,
+            seed: None,
+        }
+    }
+}
+
+pub struct RandomUniformAttrsBuilder<'a: 'b, 'b> {
+    fbb_: &'b mut flatbuffers::FlatBufferBuilder<'a>,
+    start_: flatbuffers::WIPOffset<flatbuffers::TableUnfinishedWIPOffset>,
+}
+impl<'a: 'b, 'b> RandomUniformAttrsBuilder<'a, 'b> {
+    #[inline]
+    pub fn add_shape(&mut self, shape: flatbuffers::WIPOffset<flatbuffers::Vector<'b, u32>>) {
+        self.fbb_
+            .push_slot_always::<flatbuffers::WIPOffset<_>>(RandomUniformAttrs::VT_SHAPE, shape);
+    }
+    #[inline]
+    pub fn add_high(&mut self, high: f32) {
+        self.fbb_
+            .push_slot::<f32>(RandomUniformAttrs::VT_HIGH, high, 0.0);
+    }
+    #[inline]
+    pub fn add_low(&mut self, low: f32) {
+        self.fbb_
+            .push_slot::<f32>(RandomUniformAttrs::VT_LOW, low, 0.0);
+    }
+    #[inline]
+    pub fn add_seed(&mut self, seed: f32) {
+        self.fbb_
+            .push_slot_always::<f32>(RandomUniformAttrs::VT_SEED, seed);
+    }
+    #[inline]
+    pub fn new(
+        _fbb: &'b mut flatbuffers::FlatBufferBuilder<'a>,
+    ) -> RandomUniformAttrsBuilder<'a, 'b> {
+        let start = _fbb.start_table();
+        RandomUniformAttrsBuilder {
+            fbb_: _fbb,
+            start_: start,
+        }
+    }
+    #[inline]
+    pub fn finish(self) -> flatbuffers::WIPOffset<RandomUniformAttrs<'a>> {
+        let o = self.fbb_.end_table(self.start_);
+        flatbuffers::WIPOffset::new(o.value())
+    }
+}
+
+impl core::fmt::Debug for RandomUniformAttrs<'_> {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        let mut ds = f.debug_struct("RandomUniformAttrs");
+        ds.field("shape", &self.shape());
+        ds.field("high", &self.high());
+        ds.field("low", &self.low());
+        ds.field("seed", &self.seed());
+        ds.finish()
+    }
+}
 pub enum ReduceMeanAttrsOffset {}
 #[derive(Copy, Clone, PartialEq)]
 
@@ -6384,6 +6570,21 @@ impl<'a> OperatorNode<'a> {
             None
         }
     }
+
+    #[inline]
+    #[allow(non_snake_case)]
+    pub fn attrs_as_random_uniform_attrs(&self) -> Option<RandomUniformAttrs<'a>> {
+        if self.attrs_type() == OperatorAttrs::RandomUniformAttrs {
+            self.attrs().map(|t| {
+                // Safety:
+                // Created from a valid Table for this object
+                // Which contains a valid union in this slot
+                unsafe { RandomUniformAttrs::init_from_table(t) }
+            })
+        } else {
+            None
+        }
+    }
 }
 
 impl flatbuffers::Verifiable for OperatorNode<'_> {
@@ -6427,6 +6628,7 @@ impl flatbuffers::Verifiable for OperatorNode<'_> {
                     OperatorAttrs::ScatterNDAttrs => v.verify_union_variant::<flatbuffers::ForwardsUOffset<ScatterNDAttrs>>("OperatorAttrs::ScatterNDAttrs", pos),
                     OperatorAttrs::NonMaxSuppressionAttrs => v.verify_union_variant::<flatbuffers::ForwardsUOffset<NonMaxSuppressionAttrs>>("OperatorAttrs::NonMaxSuppressionAttrs", pos),
                     OperatorAttrs::LayerNormalizationAttrs => v.verify_union_variant::<flatbuffers::ForwardsUOffset<LayerNormalizationAttrs>>("OperatorAttrs::LayerNormalizationAttrs", pos),
+                    OperatorAttrs::RandomUniformAttrs => v.verify_union_variant::<flatbuffers::ForwardsUOffset<RandomUniformAttrs>>("OperatorAttrs::RandomUniformAttrs", pos),
                     _ => Ok(()),
                 }
             })?
@@ -6810,6 +7012,16 @@ impl core::fmt::Debug for OperatorNode<'_> { ) } } + OperatorAttrs::RandomUniformAttrs => { + if let Some(x) = self.attrs_as_random_uniform_attrs() { + ds.field("attrs", &x) + } else { + ds.field( + "attrs", + &"InvalidFlatbuffer: Union discriminant does not match value.", + ) + } + } _ => { let x: Option<()> = None; ds.field("attrs", &x) From 68f3499b63691c3a809267240042d25680079b0a Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Sat, 30 Mar 2024 07:38:16 +0000 Subject: [PATCH 3/3] Document the `random` crate feature --- src/lib.rs | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 789ac669..7f65adb1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,11 @@ //! rten is a runtime for machine learning models. //! -//! RTen uses models that are exported from other frameworks such as -//! PyTorch into [ONNX](https://onnx.ai) format and then converted into the +//! RTen uses models that are exported from other frameworks such as PyTorch +//! into [ONNX](https://onnx.ai) format and then converted into the //! inference-optimized `.rten` format by the tools in this repository. //! +//! # Loading and running models +//! //! The basic workflow for loading and running a model is: //! //! 1. Load the model using [Model::load]. @@ -18,7 +20,21 @@ //! See the example projects in [rten-examples][rten_examples] to see how all //! these pieces fit together. //! +//! # Supported operators +//! +//! RTen currently implements a subset of [ONNX operators][onnx_operators]. See +//! the [`schema.fbs` FlatBuffers schema][schema_fbs] for currently supported +//! operators and attributes. +//! +//! Some operators require additional dependencies and are only available if +//! certain crate features are enabled: +//! +//! - The `random` feature enables operators that generate random numbers (eg. +//! `RandomUniform`). +//! //! [rten_examples]: https://github.com/robertknight/rten/tree/main/rten-examples +//! [onnx_operators]: https://onnx.ai/onnx/operators/ +//! [schema_fbs]: https://github.com/robertknight/rten/blob/main/src/schema.fbs #![cfg_attr( feature = "avx512", feature(stdarch_x86_avx512),