Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement RandomUniform operator #69

Merged
merged 3 commits
Mar 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 4 additions & 14 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ description = "Machine learning runtime"
license = "MIT OR Apache-2.0"
homepage = "https://github.com/robertknight/rten"
repository = "https://github.com/robertknight/rten"
resolver = "2"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

Expand All @@ -37,8 +38,10 @@ rayon = "1.7.0"
smallvec = { version = "1.10.0", features = ["union", "const_generics", "const_new"] }
rten-tensor = { path = "./rten-tensor", version = "0.4.0" }
rten-vecmath = { path = "./rten-vecmath", version = "0.4.0" }
fastrand = { version = "2.0.2", optional = true }

[dev-dependencies]
rten = { path = ".", features = ["random"] }
rten-bench = { path = "./rten-bench" }
serde_json = "1.0.91"

Expand All @@ -50,6 +53,8 @@ crate-type = ["lib", "cdylib"]
avx512 = ["rten-vecmath/avx512"]
# Generate WebAssembly API using wasm-bindgen.
wasm_api = []
# Enable operators that generate random numbers.
random = ["fastrand"]

[target.'cfg(target_arch = "wasm32")'.dependencies]
wasm-bindgen = "0.2.83"
Expand Down
4 changes: 2 additions & 2 deletions rten-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ homepage = "https://github.com/robertknight/rten"
repository = "https://github.com/robertknight/rten"

[dependencies]
fastrand = "1.9.0"
rten = { path = "../", version = "0.5.0" }
fastrand = "2.0.2"
rten = { path = "../", version = "0.5.0", features=["random"] }
rten-tensor = { path = "../rten-tensor", version = "0.4.0" }
lexopt = "0.3.0"

Expand Down
9 changes: 9 additions & 0 deletions rten-convert/rten_convert/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,15 @@ def op_node_from_onnx_operator(
attrs = sg.OneHotAttrsT()
attrs.axis = op_reader.get_attr("axis", "int", -1)

case "RandomUniform":
attrs = sg.RandomUniformAttrsT()
op_reader.check_attr("dtype", "int", 1)

attrs.seed = op_reader.get_attr("seed", "float", None)
attrs.shape = op_reader.require_attr("shape", "ints")
attrs.low = op_reader.get_attr("low", "float", 0.0)
attrs.high = op_reader.get_attr("high", "float", 1.0)

case (
"ReduceL2"
| "ReduceMax"
Expand Down
164 changes: 163 additions & 1 deletion rten-convert/rten_convert/schema_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ class OperatorType(object):
GatherElements = 92
LayerNormalization = 93
ReduceSumSquare = 94
RandomUniform = 95


class RNNDirection(object):
Expand Down Expand Up @@ -170,6 +171,7 @@ class OperatorAttrs(object):
ScatterNDAttrs = 28
NonMaxSuppressionAttrs = 29
LayerNormalizationAttrs = 30
RandomUniformAttrs = 31

def OperatorAttrsCreator(unionType, table):
from flatbuffers.table import Table
Expand Down Expand Up @@ -235,6 +237,8 @@ def OperatorAttrsCreator(unionType, table):
return NonMaxSuppressionAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().LayerNormalizationAttrs:
return LayerNormalizationAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().RandomUniformAttrs:
return RandomUniformAttrsT.InitFromBuf(table.Bytes, table.Pos)
return None


Expand Down Expand Up @@ -2675,6 +2679,164 @@ def Pack(self, builder):
return oneHotAttrs


class RandomUniformAttrs(object):
    """Read-side FlatBuffers accessor for the `RandomUniformAttrs` table.

    Generated-style code: wraps a raw byte buffer and decodes each field
    lazily through its vtable slot (offset 4 = shape, 6 = high, 8 = low,
    10 = seed). Do not hand-edit logic here; keep it in sync with the
    FlatBuffers schema/generator.
    """
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor positioned at the root table of `buf`."""
        # The buffer starts with a uoffset pointing at the root table.
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = RandomUniformAttrs()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsRandomUniformAttrs(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def RandomUniformAttrsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return True if `buf` carries the "RTEN" file identifier."""
        # b"\x52\x54\x45\x4E" is the ASCII bytes of "RTEN".
        return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x52\x54\x45\x4E", size_prefixed=size_prefixed)

    # RandomUniformAttrs
    def Init(self, buf, pos):
        # Attach this accessor to a table at position `pos` in `buf`.
        self._tab = flatbuffers.table.Table(buf, pos)

    # RandomUniformAttrs
    def Shape(self, j):
        """Return element `j` of the uint32 `shape` vector, or 0 if the field is absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            a = self._tab.Vector(o)
            # Each element is 4 bytes wide (uint32).
            return self._tab.Get(flatbuffers.number_types.Uint32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
        return 0

    # RandomUniformAttrs
    def ShapeAsNumpy(self):
        """Return the `shape` vector as a numpy uint32 array, or 0 if the field is absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint32Flags, o)
        return 0

    # RandomUniformAttrs
    def ShapeLength(self):
        """Return the number of elements in the `shape` vector, or 0 if the field is absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # RandomUniformAttrs
    def ShapeIsNone(self):
        """Return True when the `shape` field is not present in the buffer."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        return o == 0

    # RandomUniformAttrs
    def High(self):
        """Return the `high` float field; 0.0 when the field is absent from the buffer."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos)
        return 0.0

    # RandomUniformAttrs
    def Low(self):
        """Return the `low` float field; 0.0 when the field is absent from the buffer."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos)
        return 0.0

    # RandomUniformAttrs
    def Seed(self):
        """Return the optional `seed` float field, or None when it is absent."""
        # Unlike High/Low, an absent seed yields None rather than a numeric
        # default, letting callers distinguish "unset" from an explicit value.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos)
        return None

def RandomUniformAttrsStart(builder):
    """Begin a RandomUniformAttrs table with 4 vtable slots (shape, high, low, seed)."""
    builder.StartObject(4)

def RandomUniformAttrsAddShape(builder, shape):
    """Write the `shape` vector offset into slot 0 of the open table."""
    builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(shape), 0)

def RandomUniformAttrsStartShapeVector(builder, numElems):
    """Start the uint32 `shape` vector: 4-byte elements, 4-byte alignment."""
    return builder.StartVector(4, numElems, 4)

def RandomUniformAttrsAddHigh(builder, high):
    """Write `high` into slot 1; a value equal to the default 0.0 is elided."""
    builder.PrependFloat32Slot(1, high, 0.0)

def RandomUniformAttrsAddLow(builder, low):
    """Write `low` into slot 2; a value equal to the default 0.0 is elided."""
    builder.PrependFloat32Slot(2, low, 0.0)

def RandomUniformAttrsAddSeed(builder, seed):
    """Write the optional `seed` into slot 3.

    The default passed to PrependFloat32Slot is None, so when `seed` is None
    the slot is skipped entirely and the reader's Seed() reports None.
    """
    builder.PrependFloat32Slot(3, seed, None)

def RandomUniformAttrsEnd(builder):
    """Finish the open RandomUniformAttrs table and return its offset."""
    return builder.EndObject()


# `List` and `Optional` are referenced by the `# type:` comments in the
# generated classes below; guard the import so very old interpreters
# without `typing` can still load this module.
try:
    from typing import List, Optional
except ImportError:
    # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt are
    # not swallowed. `Optional` is now imported too, since the type
    # comments (e.g. `# type: Optional[float]`) reference it.
    pass

class RandomUniformAttrsT(object):
    """Mutable object-API ("T") counterpart of `RandomUniformAttrs`.

    Holds the attribute values as plain Python fields and converts to and
    from the packed FlatBuffers representation via `_UnPack` and `Pack`.
    Generated-style code; keep in sync with the schema/generator.
    """

    # RandomUniformAttrsT
    def __init__(self):
        self.shape = None  # type: List[int]
        self.high = 0.0  # type: float
        self.low = 0.0  # type: float
        self.seed = None  # type: Optional[float]

    @classmethod
    def InitFromBuf(cls, buf, pos):
        """Build a T-object from a table located at `pos` inside `buf`."""
        randomUniformAttrs = RandomUniformAttrs()
        randomUniformAttrs.Init(buf, pos)
        return cls.InitFromObj(randomUniformAttrs)

    @classmethod
    def InitFromPackedBuf(cls, buf, pos=0):
        """Build a T-object from a buffer whose root offset has not been resolved yet."""
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
        return cls.InitFromBuf(buf, pos+n)

    @classmethod
    def InitFromObj(cls, randomUniformAttrs):
        """Build a T-object by copying fields from a read-side accessor."""
        x = RandomUniformAttrsT()
        x._UnPack(randomUniformAttrs)
        return x

    # RandomUniformAttrsT
    def _UnPack(self, randomUniformAttrs):
        """Copy all fields out of the `randomUniformAttrs` accessor into self."""
        if randomUniformAttrs is None:
            return
        if not randomUniformAttrs.ShapeIsNone():
            # `np` is a module-level alias that is None when numpy is not
            # installed (assumed defined earlier in this file — the guard
            # below relies on that); fall back to a plain list in that case.
            if np is None:
                self.shape = []
                for i in range(randomUniformAttrs.ShapeLength()):
                    self.shape.append(randomUniformAttrs.Shape(i))
            else:
                self.shape = randomUniformAttrs.ShapeAsNumpy()
        self.high = randomUniformAttrs.High()
        self.low = randomUniformAttrs.Low()
        # Seed() yields None when the field is absent, preserving "unset".
        self.seed = randomUniformAttrs.Seed()

    # RandomUniformAttrsT
    def Pack(self, builder):
        """Serialize this object into `builder` and return the table offset."""
        if self.shape is not None:
            if np is not None and type(self.shape) is np.ndarray:
                # Fast path: copy the numpy array directly into the buffer.
                shape = builder.CreateNumpyVector(self.shape)
            else:
                # FlatBuffers vectors are built back-to-front, hence reversed().
                RandomUniformAttrsStartShapeVector(builder, len(self.shape))
                for i in reversed(range(len(self.shape))):
                    builder.PrependUint32(self.shape[i])
                shape = builder.EndVector()
        RandomUniformAttrsStart(builder)
        if self.shape is not None:
            RandomUniformAttrsAddShape(builder, shape)
        RandomUniformAttrsAddHigh(builder, self.high)
        RandomUniformAttrsAddLow(builder, self.low)
        RandomUniformAttrsAddSeed(builder, self.seed)
        randomUniformAttrs = RandomUniformAttrsEnd(builder)
        return randomUniformAttrs


class ReduceMeanAttrs(object):
__slots__ = ['_tab']

Expand Down Expand Up @@ -3746,7 +3908,7 @@ class OperatorNodeT(object):
def __init__(self):
self.type = 0 # type: int
self.attrsType = 0 # type: int
self.attrs = None # type: Union[None, ArgMaxAttrsT, AveragePoolAttrsT, BatchNormalizationAttrsT, CastAttrsT, ConcatAttrsT, ConstantOfShapeAttrsT, ConvAttrsT, ConvTransposeAttrsT, FlattenAttrsT, GatherAttrsT, GemmAttrsT, GRUAttrsT, LeakyReluAttrsT, LSTMAttrsT, MaxPoolAttrsT, ReduceMeanAttrsT, ReshapeAttrsT, ResizeAttrsT, SplitAttrsT, SoftmaxAttrsT, TransposeAttrsT, ModAttrsT, ScatterElementsAttrsT, OneHotAttrsT, TopKAttrsT, HardSigmoidAttrsT, TriluAttrsT, ScatterNDAttrsT, NonMaxSuppressionAttrsT, LayerNormalizationAttrsT]
self.attrs = None # type: Union[None, ArgMaxAttrsT, AveragePoolAttrsT, BatchNormalizationAttrsT, CastAttrsT, ConcatAttrsT, ConstantOfShapeAttrsT, ConvAttrsT, ConvTransposeAttrsT, FlattenAttrsT, GatherAttrsT, GemmAttrsT, GRUAttrsT, LeakyReluAttrsT, LSTMAttrsT, MaxPoolAttrsT, ReduceMeanAttrsT, ReshapeAttrsT, ResizeAttrsT, SplitAttrsT, SoftmaxAttrsT, TransposeAttrsT, ModAttrsT, ScatterElementsAttrsT, OneHotAttrsT, TopKAttrsT, HardSigmoidAttrsT, TriluAttrsT, ScatterNDAttrsT, NonMaxSuppressionAttrsT, LayerNormalizationAttrsT, RandomUniformAttrsT]
self.inputs = None # type: List[int]
self.outputs = None # type: List[int]

Expand Down
2 changes: 1 addition & 1 deletion rten-examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ homepage = "https://github.com/robertknight/rten"
repository = "https://github.com/robertknight/rten"

[dependencies]
fastrand = "1.9.0"
fastrand = "2.0.2"
hound = "3.5.1"
image = { version = "0.24.6", default-features = false, features = ["png", "jpeg", "jpeg_rayon", "webp"] }
lexopt = "0.3.0"
Expand Down
2 changes: 1 addition & 1 deletion rten-vecmath/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ homepage = "https://github.com/robertknight/rten"
repository = "https://github.com/robertknight/rten"

[dev-dependencies]
fastrand = "1.9.0"
fastrand = "2.0.2"
libm = "0.2.6"

[lib]
Expand Down
22 changes: 19 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
//! rten is a runtime for machine learning models.
//!
//! RTen uses models that are exported from other frameworks such as
//! PyTorch into [ONNX](https://onnx.ai) format and then converted into the
//! RTen uses models that are exported from other frameworks such as PyTorch
//! into [ONNX](https://onnx.ai) format and then converted into the
//! inference-optimized `.rten` format by the tools in this repository.
//!
//! # Loading and running models
//!
//! The basic workflow for loading and running a model is:
//!
//! 1. Load the model using [Model::load].
Expand All @@ -18,7 +20,21 @@
//! See the example projects in [rten-examples][rten_examples] to see how all
//! these pieces fit together.
//!
//! # Supported operators
//!
//! RTen currently implements a subset of [ONNX operators][onnx_operators]. See
//! the [`schema.fbs` FlatBuffers schema][schema_fbs] for currently supported
//! operators and attributes.
//!
//! Some operators require additional dependencies and are only available if
//! certain crate features are enabled:
//!
//! - The `random` feature enables operators that generate random numbers (eg.
//! `RandomUniform`).
//!
//! [rten_examples]: https://github.com/robertknight/rten/tree/main/rten-examples
//! [onnx_operators]: https://onnx.ai/onnx/operators/
//! [schema_fbs]: https://github.com/robertknight/rten/blob/main/src/schema.fbs
#![cfg_attr(
feature = "avx512",
feature(stdarch_x86_avx512),
Expand Down Expand Up @@ -48,7 +64,7 @@ pub mod ctc;
pub mod ops;

pub use graph::{Dimension, NodeId, RunOptions};
pub use model::{DefaultOperatorFactory, Model, ModelLoadError, NodeInfo, OpRegistry};
pub use model::{DefaultOperatorFactory, Model, ModelLoadError, NodeInfo, OpRegistry, ReadOpError};
pub use model_metadata::ModelMetadata;
pub use ops::{FloatOperators, Input, Operators, Output};
pub use timer::Timer;
Expand Down
Loading
Loading