diff --git a/rten-cli/src/main.rs b/rten-cli/src/main.rs index 051afb0d..c1e58a27 100644 --- a/rten-cli/src/main.rs +++ b/rten-cli/src/main.rs @@ -3,7 +3,7 @@ use std::error::Error; use std::fs; use std::time::Instant; -use rten::{Dimension, Input, Model, NodeId, Output, RunOptions}; +use rten::{Dimension, Input, Model, ModelMetadata, NodeId, Output, RunOptions}; use rten_tensor::prelude::*; use rten_tensor::Tensor; @@ -37,6 +37,11 @@ fn parse_args() -> Result { Usage: {bin_name} [OPTIONS] +Args: + + Path to '.rten' model to inspect and run. + +Options: -t, --timing Output timing info -v, --verbose Enable verbose logging -h, --help Print help @@ -66,28 +71,27 @@ fn format_param_count(n: usize) -> String { } } -/// Tool for inspecting converted ONNX models and running them with randomly -/// generated inputs. -/// -/// ``` -/// tools/convert-onnx.py model.onnx output.rten -/// cargo run -p rten-cli --release output.rten -/// ``` -/// -/// To get detailed timing information set the `RTEN_TIMING` env var before -/// running. See `docs/profiling.md`. -fn main() -> Result<(), Box> { - let args = parse_args()?; - let model_bytes = fs::read(args.model)?; - let model = Model::load(&model_bytes)?; +fn print_metadata(metadata: &ModelMetadata) { + fn print_field(name: &str, value: Option) { + if let Some(value) = value { + println!(" {}: {}", name, value); + } + } - println!( - "Model stats: {} inputs, {} outputs, {} params", - model.input_ids().len(), - model.output_ids().len(), - format_param_count(model.total_params()), - ); + println!("Metadata:"); + print_field("ONNX hash", metadata.onnx_hash()); + print_field("Description", metadata.description()); + print_field("License", metadata.license()); + print_field("Commit", metadata.commit()); + print_field("Repository", metadata.code_repository()); + print_field("Model repository", metadata.model_repository()); + print_field("Run ID", metadata.run_id()); + print_field("Run URL", metadata.run_url()); +} +/// Generate random inputs for `model` using shape metadata and heuristics, +/// run it, and print details of the output. +fn run_with_random_input(model: &Model, run_opts: RunOptions) -> Result<(), Box> { let mut rng = fastrand::Rng::new(); // Generate random ints that are likely to be valid token IDs in a language @@ -165,27 +169,21 @@ fn main() -> Result<(), Box> { .as_ref() .and_then(|ni| ni.name()) .unwrap_or("(unnamed)"); - println!("Input \"{name}\" resolved shape {:?}", input.shape()); + println!(" Input \"{name}\" generated shape {:?}", input.shape()); } // Run model and summarize outputs. let start = Instant::now(); - let outputs = model.run( - &inputs, - model.output_ids(), - Some(RunOptions { - timing: args.timing, - verbose: args.verbose, - ..Default::default() - }), - )?; + let outputs = model.run(&inputs, model.output_ids(), Some(run_opts))?; let elapsed = start.elapsed().as_millis(); + println!(); println!( - "Model returned {} outputs in {:.2}ms", + " Model returned {} outputs in {:.2}ms.", outputs.len(), elapsed ); + println!(); let output_names: Vec = model .output_ids() @@ -204,7 +202,7 @@ fn main() -> Result<(), Box> { Output::IntTensor(_) => "i32", }; println!( - "Output {i} \"{name}\" data type {} shape: {:?}", + " Output {i} \"{name}\" data type {} shape: {:?}", dtype, output.shape() ); @@ -212,3 +210,42 @@ fn main() -> Result<(), Box> { Ok(()) } + +/// Tool for inspecting converted ONNX models and running them with randomly +/// generated inputs. 
+/// +/// ``` +/// tools/convert-onnx.py model.onnx output.rten +/// cargo run -p rten-cli --release output.rten +/// ``` +/// +/// To get detailed timing information set the `RTEN_TIMING` env var before +/// running. See `docs/profiling.md`. +fn main() -> Result<(), Box> { + let args = parse_args()?; + let model_bytes = fs::read(args.model)?; + let model = Model::load(&model_bytes)?; + + println!( + "Model summary: {} inputs, {} outputs, {} params", + model.input_ids().len(), + model.output_ids().len(), + format_param_count(model.total_params()), + ); + println!(); + + print_metadata(model.metadata()); + + println!(); + println!("Running model with random inputs..."); + run_with_random_input( + &model, + RunOptions { + timing: args.timing, + verbose: args.verbose, + ..Default::default() + }, + )?; + + Ok(()) +} diff --git a/src/lib.rs b/src/lib.rs index 599e667d..fe5b3fb1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,6 +28,7 @@ mod gemm; mod graph; mod iter_util; mod model; +mod model_metadata; mod number; mod slice_reductions; mod timer; @@ -44,6 +45,7 @@ pub mod ops; pub use graph::{Dimension, NodeId, RunOptions}; pub use model::{DefaultOperatorFactory, Model, ModelLoadError, NodeInfo, OpRegistry}; +pub use model_metadata::ModelMetadata; pub use ops::{FloatOperators, Input, Operators, Output}; pub use timer::Timer; pub use timing::TimingSort; diff --git a/src/model.rs b/src/model.rs index 1d866b3f..1c28150f 100644 --- a/src/model.rs +++ b/src/model.rs @@ -9,6 +9,7 @@ use rten_tensor::Tensor; use smallvec::smallvec; use crate::graph::{Dimension, Graph, Node, NodeId, RunError, RunOptions}; +use crate::model_metadata::ModelMetadata; use crate::ops; use crate::ops::{ BoxOrder, CoordTransformMode, DataType, Direction, Input, NearestMode, Operator, Output, @@ -52,6 +53,7 @@ pub struct Model { input_ids: Vec, output_ids: Vec, graph: Graph, + metadata: ModelMetadata, } /// Provides access to metadata about a graph node. @@ -144,6 +146,11 @@ impl Model { self.graph.get_node(id).map(|node| NodeInfo { node }) } + /// Return metadata about the model. + pub fn metadata(&self) -> &ModelMetadata { + &self.metadata + } + /// Return the IDs of input nodes. pub fn input_ids(&self) -> &[NodeId] { &self.input_ids @@ -1108,11 +1115,17 @@ fn load_model(data: &[u8], registry: &OpRegistry) -> Result { nodes: Vec>>, input_ids: Vec, output_ids: Vec, + metadata: Option>>, } enum NodeData<'a> { @@ -126,6 +127,11 @@ enum NodeData<'a> { Operator(WIPOffset>), } +/// Arguments for [ModelBuilder::add_metadata]. +pub struct MetadataArgs { + pub onnx_hash: Option, +} + struct PadArgs { pad_mode: sg::PadMode, pads: Option>, @@ -152,6 +158,7 @@ impl<'a> ModelBuilder<'a> { nodes: Vec::new(), input_ids: Vec::new(), output_ids: Vec::new(), + metadata: None, } } @@ -683,6 +690,19 @@ impl<'a> ModelBuilder<'a> { self.output_ids.push(node_id); } + /// Add model metadata + pub fn add_metadata(&mut self, metadata: MetadataArgs) { + let hash = metadata + .onnx_hash + .as_ref() + .map(|hash| self.builder.create_string(hash)); + let mut meta_builder = sg::MetadataBuilder::new(&mut self.builder); + if let Some(hash) = hash { + meta_builder.add_onnx_hash(hash); + } + self.metadata = Some(meta_builder.finish()); + } + /// Finish writing the model data to the buffer and return the buffer's contents. 
pub fn finish(mut self) -> Vec { let inputs_vec = self.builder.create_vector(&self.input_ids[..]); @@ -703,6 +723,7 @@ impl<'a> ModelBuilder<'a> { &sg::ModelArgs { schema_version: 1, graph: Some(graph), + metadata: self.metadata, }, ); diff --git a/src/model_metadata.rs b/src/model_metadata.rs new file mode 100644 index 00000000..88beb396 --- /dev/null +++ b/src/model_metadata.rs @@ -0,0 +1,148 @@ +use crate::schema_generated as sg; + +/// Metadata for an RTen model. +/// +/// This provides access to information such as: +/// +/// - The ONNX model that was used to generate it +/// - The license +/// - Details of the training run that produced the model +/// - Related URLs +#[derive(Default)] +pub struct ModelMetadata { + onnx_hash: Option, + description: Option, + license: Option, + commit: Option, + code_repository: Option, + model_repository: Option, + run_id: Option, + run_url: Option, +} + +impl ModelMetadata { + /// Deserialize a ModelMetadata from data in a flatbuffers file. + pub(crate) fn deserialize(metadata: sg::Metadata<'_>) -> ModelMetadata { + ModelMetadata { + onnx_hash: metadata.onnx_hash().map(|s| s.to_string()), + description: metadata.description().map(|s| s.to_string()), + license: metadata.license().map(|s| s.to_string()), + commit: metadata.commit().map(|s| s.to_string()), + code_repository: metadata.code_repository().map(|s| s.to_string()), + model_repository: metadata.model_repository().map(|s| s.to_string()), + run_id: metadata.run_id().map(|s| s.to_string()), + run_url: metadata.run_url().map(|s| s.to_string()), + } + } + + /// Return the SHA-256 hash of the ONNX model used to generate this RTen + /// model. + pub fn onnx_hash(&self) -> Option<&str> { + self.onnx_hash.as_deref() + } + + /// Return a short description of what this model does. + pub fn description(&self) -> Option<&str> { + self.description.as_deref() + } + + /// Return the license identifier for this model. It is recommended that + /// this be an SPDX identifier. + pub fn license(&self) -> Option<&str> { + self.license.as_deref() + } + + /// Return the commit from the repository referenced by + /// [code_repository](ModelMetadata::code_repository) which was used to + /// create this model. + pub fn commit(&self) -> Option<&str> { + self.commit.as_deref() + } + + /// Return the URL of the repository (eg. on GitHub) containing the model's + /// code. + pub fn code_repository(&self) -> Option<&str> { + self.code_repository.as_deref() + } + + /// Return the URL of the repository (eg. on Hugging Face) where the model + /// is hosted. + pub fn model_repository(&self) -> Option<&str> { + self.model_repository.as_deref() + } + + /// Return the ID of the training run that produced this model. + /// + /// When models are developed using experiment tracking services such as + /// Weights and Biases, this enables looking up the training run that + /// produced the model. + pub fn run_id(&self) -> Option<&str> { + self.run_id.as_deref() + } + + /// Return a URL for the training run that produced this model. + /// + /// When models are developed using experiment tracking services such as + /// Weights and Biases, this enables looking up the training run that + /// produced the model. 
+ pub fn run_url(&self) -> Option<&str> { + self.run_url.as_deref() + } +} + +#[cfg(test)] +mod tests { + use super::ModelMetadata; + use crate::schema_generated as sg; + use flatbuffers::FlatBufferBuilder; + + #[test] + fn test_model_metadata() { + let mut builder = FlatBufferBuilder::with_capacity(1024); + + let onnx_hash = builder.create_string("abc"); + let description = builder.create_string("A simple model"); + let license = builder.create_string("BSD-2-Clause"); + let commit = builder.create_string("def"); + let code_repository = builder.create_string("https://github.com/robertknight/rten"); + let model_repository = builder.create_string("https://huggingface.co/robertknight/rten"); + let run_id = builder.create_string("1234"); + let run_url = + builder.create_string("https://wandb.ai/robertknight/text-detection/runs/1234"); + + let mut meta_builder = sg::MetadataBuilder::new(&mut builder); + meta_builder.add_onnx_hash(onnx_hash); + meta_builder.add_description(description); + meta_builder.add_license(license); + meta_builder.add_commit(commit); + meta_builder.add_code_repository(code_repository); + meta_builder.add_model_repository(model_repository); + meta_builder.add_run_id(run_id); + meta_builder.add_run_url(run_url); + let metadata = meta_builder.finish(); + + builder.finish_minimal(metadata); + let data = builder.finished_data(); + + let deserialized_meta = flatbuffers::root::(&data).unwrap(); + let model_metadata = ModelMetadata::deserialize(deserialized_meta); + + assert_eq!(model_metadata.onnx_hash(), Some("abc")); + assert_eq!(model_metadata.description(), Some("A simple model")); + assert_eq!(model_metadata.license(), Some("BSD-2-Clause")); + assert_eq!(model_metadata.commit(), Some("def")); + assert_eq!( + model_metadata.code_repository(), + Some("https://github.com/robertknight/rten") + ); + assert_eq!( + model_metadata.model_repository(), + Some("https://huggingface.co/robertknight/rten") + ); + assert_eq!(model_metadata.run_id(), Some("1234")); + assert_eq!( + model_metadata.run_url(), + Some("https://wandb.ai/robertknight/text-detection/runs/1234") + ); + } +} diff --git a/src/schema.fbs b/src/schema.fbs index a6a63c95..83c093c5 100644 --- a/src/schema.fbs +++ b/src/schema.fbs @@ -437,9 +437,40 @@ table Graph { outputs:[uint]; } +table Metadata { + // SHA-256 hash of the ONNX model that was used as the source for this RTen + // model. + onnx_hash:string; + + // A short description of what this model does. + description:string; + + // Identifier for the license used in this model. + // + // This should be an SPDX (https://spdx.org/licenses/) identifier for openly + // licensed models. + license:string; + + // Commit ID for the code that produced this model. + commit:string; + + // URL of repository where the model's code is hosted (eg. GitHub). + code_repository:string; + + // URL of repository where the model is hosted (eg. Hugging Face). + model_repository:string; + + // Identifier for the training run that produced this model. + run_id:string; + + // URL of logs etc. for the training run that produced this model. 
+ run_url:string; +} + table Model { schema_version:int; graph:Graph (required); + metadata:Metadata; } root_type Model; diff --git a/src/schema_generated.rs b/src/schema_generated.rs index 4d196921..13ae30e8 100644 --- a/src/schema_generated.rs +++ b/src/schema_generated.rs @@ -7943,6 +7943,292 @@ impl core::fmt::Debug for Graph<'_> { ds.finish() } } +pub enum MetadataOffset {} +#[derive(Copy, Clone, PartialEq)] + +pub struct Metadata<'a> { + pub _tab: flatbuffers::Table<'a>, +} + +impl<'a> flatbuffers::Follow<'a> for Metadata<'a> { + type Inner = Metadata<'a>; + #[inline] + unsafe fn follow(buf: &'a [u8], loc: usize) -> Self::Inner { + Self { + _tab: flatbuffers::Table::new(buf, loc), + } + } +} + +impl<'a> Metadata<'a> { + pub const VT_ONNX_HASH: flatbuffers::VOffsetT = 4; + pub const VT_DESCRIPTION: flatbuffers::VOffsetT = 6; + pub const VT_LICENSE: flatbuffers::VOffsetT = 8; + pub const VT_COMMIT: flatbuffers::VOffsetT = 10; + pub const VT_CODE_REPOSITORY: flatbuffers::VOffsetT = 12; + pub const VT_MODEL_REPOSITORY: flatbuffers::VOffsetT = 14; + pub const VT_RUN_ID: flatbuffers::VOffsetT = 16; + pub const VT_RUN_URL: flatbuffers::VOffsetT = 18; + + #[inline] + pub unsafe fn init_from_table(table: flatbuffers::Table<'a>) -> Self { + Metadata { _tab: table } + } + #[allow(unused_mut)] + pub fn create<'bldr: 'args, 'args: 'mut_bldr, 'mut_bldr>( + _fbb: &'mut_bldr mut flatbuffers::FlatBufferBuilder<'bldr>, + args: &'args MetadataArgs<'args>, + ) -> flatbuffers::WIPOffset> { + let mut builder = MetadataBuilder::new(_fbb); + if let Some(x) = args.run_url { + builder.add_run_url(x); + } + if let Some(x) = args.run_id { + builder.add_run_id(x); + } + if let Some(x) = args.model_repository { + builder.add_model_repository(x); + } + if let Some(x) = args.code_repository { + builder.add_code_repository(x); + } + if let Some(x) = args.commit { + builder.add_commit(x); + } + if let Some(x) = args.license { + builder.add_license(x); + } + if let Some(x) = args.description { + builder.add_description(x); + } + if let Some(x) = args.onnx_hash { + builder.add_onnx_hash(x); + } + builder.finish() + } + + #[inline] + pub fn onnx_hash(&self) -> Option<&'a str> { + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab + .get::>(Metadata::VT_ONNX_HASH, None) + } + } + #[inline] + pub fn description(&self) -> Option<&'a str> { + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab + .get::>(Metadata::VT_DESCRIPTION, None) + } + } + #[inline] + pub fn license(&self) -> Option<&'a str> { + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab + .get::>(Metadata::VT_LICENSE, None) + } + } + #[inline] + pub fn commit(&self) -> Option<&'a str> { + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab + .get::>(Metadata::VT_COMMIT, None) + } + } + #[inline] + pub fn code_repository(&self) -> Option<&'a str> { + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab + .get::>(Metadata::VT_CODE_REPOSITORY, None) + } + } + #[inline] + pub fn model_repository(&self) -> Option<&'a str> { + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab + .get::>(Metadata::VT_MODEL_REPOSITORY, None) 
+ } + } + #[inline] + pub fn run_id(&self) -> Option<&'a str> { + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab + .get::>(Metadata::VT_RUN_ID, None) + } + } + #[inline] + pub fn run_url(&self) -> Option<&'a str> { + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab + .get::>(Metadata::VT_RUN_URL, None) + } + } +} + +impl flatbuffers::Verifiable for Metadata<'_> { + #[inline] + fn run_verifier( + v: &mut flatbuffers::Verifier, + pos: usize, + ) -> Result<(), flatbuffers::InvalidFlatbuffer> { + use self::flatbuffers::Verifiable; + v.visit_table(pos)? + .visit_field::>( + "onnx_hash", + Self::VT_ONNX_HASH, + false, + )? + .visit_field::>( + "description", + Self::VT_DESCRIPTION, + false, + )? + .visit_field::>("license", Self::VT_LICENSE, false)? + .visit_field::>("commit", Self::VT_COMMIT, false)? + .visit_field::>( + "code_repository", + Self::VT_CODE_REPOSITORY, + false, + )? + .visit_field::>( + "model_repository", + Self::VT_MODEL_REPOSITORY, + false, + )? + .visit_field::>("run_id", Self::VT_RUN_ID, false)? + .visit_field::>("run_url", Self::VT_RUN_URL, false)? + .finish(); + Ok(()) + } +} +pub struct MetadataArgs<'a> { + pub onnx_hash: Option>, + pub description: Option>, + pub license: Option>, + pub commit: Option>, + pub code_repository: Option>, + pub model_repository: Option>, + pub run_id: Option>, + pub run_url: Option>, +} +impl<'a> Default for MetadataArgs<'a> { + #[inline] + fn default() -> Self { + MetadataArgs { + onnx_hash: None, + description: None, + license: None, + commit: None, + code_repository: None, + model_repository: None, + run_id: None, + run_url: None, + } + } +} + +pub struct MetadataBuilder<'a: 'b, 'b> { + fbb_: &'b mut flatbuffers::FlatBufferBuilder<'a>, + start_: flatbuffers::WIPOffset, +} +impl<'a: 'b, 'b> MetadataBuilder<'a, 'b> { + #[inline] + pub fn add_onnx_hash(&mut self, onnx_hash: flatbuffers::WIPOffset<&'b str>) { + self.fbb_ + .push_slot_always::>(Metadata::VT_ONNX_HASH, onnx_hash); + } + #[inline] + pub fn add_description(&mut self, description: flatbuffers::WIPOffset<&'b str>) { + self.fbb_ + .push_slot_always::>(Metadata::VT_DESCRIPTION, description); + } + #[inline] + pub fn add_license(&mut self, license: flatbuffers::WIPOffset<&'b str>) { + self.fbb_ + .push_slot_always::>(Metadata::VT_LICENSE, license); + } + #[inline] + pub fn add_commit(&mut self, commit: flatbuffers::WIPOffset<&'b str>) { + self.fbb_ + .push_slot_always::>(Metadata::VT_COMMIT, commit); + } + #[inline] + pub fn add_code_repository(&mut self, code_repository: flatbuffers::WIPOffset<&'b str>) { + self.fbb_.push_slot_always::>( + Metadata::VT_CODE_REPOSITORY, + code_repository, + ); + } + #[inline] + pub fn add_model_repository(&mut self, model_repository: flatbuffers::WIPOffset<&'b str>) { + self.fbb_.push_slot_always::>( + Metadata::VT_MODEL_REPOSITORY, + model_repository, + ); + } + #[inline] + pub fn add_run_id(&mut self, run_id: flatbuffers::WIPOffset<&'b str>) { + self.fbb_ + .push_slot_always::>(Metadata::VT_RUN_ID, run_id); + } + #[inline] + pub fn add_run_url(&mut self, run_url: flatbuffers::WIPOffset<&'b str>) { + self.fbb_ + .push_slot_always::>(Metadata::VT_RUN_URL, run_url); + } + #[inline] + pub fn new(_fbb: &'b mut flatbuffers::FlatBufferBuilder<'a>) -> MetadataBuilder<'a, 'b> { + let start = _fbb.start_table(); + MetadataBuilder { + fbb_: _fbb, + start_: start, + } + } + #[inline] + pub fn finish(self) 
-> flatbuffers::WIPOffset> { + let o = self.fbb_.end_table(self.start_); + flatbuffers::WIPOffset::new(o.value()) + } +} + +impl core::fmt::Debug for Metadata<'_> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + let mut ds = f.debug_struct("Metadata"); + ds.field("onnx_hash", &self.onnx_hash()); + ds.field("description", &self.description()); + ds.field("license", &self.license()); + ds.field("commit", &self.commit()); + ds.field("code_repository", &self.code_repository()); + ds.field("model_repository", &self.model_repository()); + ds.field("run_id", &self.run_id()); + ds.field("run_url", &self.run_url()); + ds.finish() + } +} pub enum ModelOffset {} #[derive(Copy, Clone, PartialEq)] @@ -7963,6 +8249,7 @@ impl<'a> flatbuffers::Follow<'a> for Model<'a> { impl<'a> Model<'a> { pub const VT_SCHEMA_VERSION: flatbuffers::VOffsetT = 4; pub const VT_GRAPH: flatbuffers::VOffsetT = 6; + pub const VT_METADATA: flatbuffers::VOffsetT = 8; #[inline] pub unsafe fn init_from_table(table: flatbuffers::Table<'a>) -> Self { @@ -7974,6 +8261,9 @@ impl<'a> Model<'a> { args: &'args ModelArgs<'args>, ) -> flatbuffers::WIPOffset> { let mut builder = ModelBuilder::new(_fbb); + if let Some(x) = args.metadata { + builder.add_metadata(x); + } if let Some(x) = args.graph { builder.add_graph(x); } @@ -8003,6 +8293,16 @@ impl<'a> Model<'a> { .unwrap() } } + #[inline] + pub fn metadata(&self) -> Option> { + // Safety: + // Created from valid Table for this object + // which contains a valid value in this slot + unsafe { + self._tab + .get::>(Model::VT_METADATA, None) + } + } } impl flatbuffers::Verifiable for Model<'_> { @@ -8015,6 +8315,11 @@ impl flatbuffers::Verifiable for Model<'_> { v.visit_table(pos)? .visit_field::("schema_version", Self::VT_SCHEMA_VERSION, false)? .visit_field::>("graph", Self::VT_GRAPH, true)? + .visit_field::>( + "metadata", + Self::VT_METADATA, + false, + )? .finish(); Ok(()) } @@ -8022,6 +8327,7 @@ impl flatbuffers::Verifiable for Model<'_> { pub struct ModelArgs<'a> { pub schema_version: i32, pub graph: Option>>, + pub metadata: Option>>, } impl<'a> Default for ModelArgs<'a> { #[inline] @@ -8029,6 +8335,7 @@ impl<'a> Default for ModelArgs<'a> { ModelArgs { schema_version: 0, graph: None, // required field + metadata: None, } } } @@ -8049,6 +8356,11 @@ impl<'a: 'b, 'b> ModelBuilder<'a, 'b> { .push_slot_always::>(Model::VT_GRAPH, graph); } #[inline] + pub fn add_metadata(&mut self, metadata: flatbuffers::WIPOffset>) { + self.fbb_ + .push_slot_always::>(Model::VT_METADATA, metadata); + } + #[inline] pub fn new(_fbb: &'b mut flatbuffers::FlatBufferBuilder<'a>) -> ModelBuilder<'a, 'b> { let start = _fbb.start_table(); ModelBuilder { @@ -8069,6 +8381,7 @@ impl core::fmt::Debug for Model<'_> { let mut ds = f.debug_struct("Model"); ds.field("schema_version", &self.schema_version()); ds.field("graph", &self.graph()); + ds.field("metadata", &self.metadata()); ds.finish() } } diff --git a/tools/convert-onnx.py b/tools/convert-onnx.py index 48e53bcd..3aa30465 100755 --- a/tools/convert-onnx.py +++ b/tools/convert-onnx.py @@ -1,6 +1,9 @@ #!/usr/bin/env python from argparse import ArgumentParser +from dataclasses import dataclass +import hashlib +import json from os.path import splitext import sys from typing import Any, Callable, Literal, Optional, cast @@ -142,6 +145,28 @@ def __init__(self, nodes: list[Node], inputs: list[int], outputs: list[int]): self.outputs = outputs +@dataclass +class Metadata: + """ + Model metadata. 
+ + This corresponds to the `ModelMetadata` struct in RTen. See its docs for + details of the individual fields. + + When adding new fields here, they also need to be added to + `METADATA_BUILDER_FNS`. + """ + + code_repository: Optional[str] = None + commit: Optional[str] = None + description: Optional[str] = None + license: Optional[str] = None + model_repository: Optional[str] = None + onnx_hash: Optional[str] = None + run_id: Optional[str] = None + run_url: Optional[str] = None + + # Mapping of ONNX attribute types to the field on an AttributeProto which # contains the value. Note that if you try to access the wrong field on an # AttributeProto, you get a default value instead of an exception. @@ -601,7 +626,11 @@ def op_node_from_onnx_operator( match to: case TensorProto.DataType.FLOAT: attrs.to = sg.DataType.Float - case TensorProto.DataType.BOOL | TensorProto.DataType.INT32 | TensorProto.DataType.INT64: + case ( + TensorProto.DataType.BOOL + | TensorProto.DataType.INT32 + | TensorProto.DataType.INT64 + ): attrs.to = sg.DataType.Int32 case _: raise Exception(f"Unsupported target type for cast {to}") @@ -769,7 +798,14 @@ def op_node_from_onnx_operator( attrs = sg.OneHotAttrsT() attrs.axis = op_reader.get_attr("axis", "int", -1) - case "ReduceL2" | "ReduceMax" | "ReduceMean" | "ReduceMin" | "ReduceProd" | "ReduceSum": + case ( + "ReduceL2" + | "ReduceMax" + | "ReduceMean" + | "ReduceMin" + | "ReduceProd" + | "ReduceSum" + ): attrs = sg.ReduceMeanAttrsT() attrs.axes = op_reader.get_attr("axes", "ints", None) attrs.keepDims = bool(op_reader.get_attr("keepdims", "int", 1)) @@ -1100,16 +1136,44 @@ def write_dim(builder, dim: str | int) -> int: return sg.ValueNodeEnd(builder) -def write_graph(graph: Graph, out_path: str): - """ - Serialize a model graph into a flatbuffers model. +METADATA_BUILDER_FNS = { + "code_repository": sg.MetadataAddCodeRepository, + "commit": sg.MetadataAddCommit, + "description": sg.MetadataAddDescription, + "license": sg.MetadataAddLicense, + "model_repository": sg.MetadataAddModelRepository, + "onnx_hash": sg.MetadataAddOnnxHash, + "run_id": sg.MetadataAddRunId, + "run_url": sg.MetadataAddRunUrl, +} +""" +Map of metadata field to function that serializes this field. +""" - This serializes the parsed graph representation into the flatbuffers-based - model format that this library uses. + +def build_metadata(builder: flatbuffers.Builder, metadata: Metadata): + """ + Serialize model metadata into a flatbuffers model. """ - builder = flatbuffers.Builder(initialSize=1024) + # Map of field name to flatbuffer string offset. + field_values = {} + + for field in METADATA_BUILDER_FNS.keys(): + if val := getattr(metadata, field): + field_values[field] = builder.CreateString(val) + + sg.MetadataStart(builder) + for field, builder_fn in METADATA_BUILDER_FNS.items(): + if val := field_values.get(field): + builder_fn(builder, val) + return sg.MetadataEnd(builder) + +def build_graph(builder: flatbuffers.Builder, graph: Graph): + """ + Serialize a computation graph into a flatbuffers model. + """ node_offsets = [] for node in graph.nodes: match node: @@ -1141,11 +1205,30 @@ def write_graph(graph: Graph, out_path: str): sg.GraphAddNodes(builder, graph_nodes) sg.GraphAddInputs(builder, inputs) sg.GraphAddOutputs(builder, outputs) - graph = sg.GraphEnd(builder) + return sg.GraphEnd(builder) + + +def write_model(graph: Graph, metadata: Metadata, out_path: str): + """ + Serialize a model into a flatbuffers model. 
+ + This serializes the parsed graph representation into the flatbuffers-based + model format that this library uses. + + :param graph: The main graph for the model + :param metadata: Model metadata + :param out_path: Output .rten model path + """ + + builder = flatbuffers.Builder(initialSize=1024) + + graph = build_graph(builder, graph) + metadata = build_metadata(builder, metadata) sg.ModelStart(builder) sg.ModelAddSchemaVersion(builder, 1) sg.ModelAddGraph(builder, graph) + sg.ModelAddMetadata(builder, metadata) model = sg.ModelEnd(builder) builder.Finish(model) @@ -1155,23 +1238,57 @@ def write_graph(graph: Graph, out_path: str): output.write(data) +def sha256(filename: str) -> str: + """Generate SHA-256 hash of a file as a hex string.""" + hasher = hashlib.sha256() + with open(filename, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hasher.update(chunk) + return hasher.hexdigest() + + +def generate_metadata(onnx_path: str, metadata_path: Optional[str] = None) -> Metadata: + """ + Generate metadata to embed into RTen model. + + :param onnx_path: Path to .onnx file + :param metadata_path: Path to JSON file containing additional metadata + """ + onnx_hash = sha256(onnx_path) + + fields = {"onnx_hash": onnx_hash} + if metadata_path: + with open(metadata_path) as fp: + metadata_dict = json.load(fp) + + for field in METADATA_BUILDER_FNS.keys(): + if field == "onnx_hash": + # This is handled separately. + continue + fields[field] = metadata_dict.get(field) + + return Metadata(**fields) + + def main(): parser = ArgumentParser(description="Convert ONNX models to .rten format.") parser.add_argument("model", help="Input ONNX model") + parser.add_argument( + "-m", "--metadata", help="Path to JSON file containing model metadata." + ) parser.add_argument("out_name", help="Output model file name", nargs="?") args = parser.parse_args() - model_path = args.model - - model = onnx.load(model_path) + model = onnx.load(args.model) graph = graph_from_onnx_graph(model.graph) + metadata = generate_metadata(args.model, args.metadata) output_path = args.out_name if output_path is None: - model_basename = splitext(model_path)[0] + model_basename = splitext(args.model)[0] output_path = f"{model_basename}.rten" - write_graph(graph, output_path) + write_model(graph, metadata, output_path) if __name__ == "__main__": diff --git a/tools/schema_generated.py b/tools/schema_generated.py index 6a7d975b..0c74d47f 100644 --- a/tools/schema_generated.py +++ b/tools/schema_generated.py @@ -4755,6 +4755,198 @@ def Pack(self, builder): return graph +class Metadata(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = Metadata() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsMetadata(cls, buf, offset=0): + """This method is deprecated. 
Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + @classmethod + def MetadataBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x52\x54\x45\x4E", size_prefixed=size_prefixed) + + # Metadata + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # Metadata + def OnnxHash(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + # Metadata + def Description(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + # Metadata + def License(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + # Metadata + def Commit(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + # Metadata + def CodeRepository(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + # Metadata + def ModelRepository(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + # Metadata + def RunId(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + # Metadata + def RunUrl(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + +def MetadataStart(builder): + builder.StartObject(8) + +def MetadataAddOnnxHash(builder, onnxHash): + builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(onnxHash), 0) + +def MetadataAddDescription(builder, description): + builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(description), 0) + +def MetadataAddLicense(builder, license): + builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(license), 0) + +def MetadataAddCommit(builder, commit): + builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(commit), 0) + +def MetadataAddCodeRepository(builder, codeRepository): + builder.PrependUOffsetTRelativeSlot(4, flatbuffers.number_types.UOffsetTFlags.py_type(codeRepository), 0) + +def MetadataAddModelRepository(builder, modelRepository): + builder.PrependUOffsetTRelativeSlot(5, flatbuffers.number_types.UOffsetTFlags.py_type(modelRepository), 0) + +def MetadataAddRunId(builder, runId): + builder.PrependUOffsetTRelativeSlot(6, flatbuffers.number_types.UOffsetTFlags.py_type(runId), 0) + +def MetadataAddRunUrl(builder, runUrl): + builder.PrependUOffsetTRelativeSlot(7, flatbuffers.number_types.UOffsetTFlags.py_type(runUrl), 0) + +def MetadataEnd(builder): + return builder.EndObject() + + + +class MetadataT(object): + + # MetadataT + def __init__(self): + self.onnxHash = None # type: str + self.description = None # type: str + self.license = None # type: str + self.commit = None # type: str + self.codeRepository = None # type: str + self.modelRepository = None # type: str + self.runId = None # type: str + self.runUrl = None # type: str + + @classmethod + def 
InitFromBuf(cls, buf, pos): + metadata = Metadata() + metadata.Init(buf, pos) + return cls.InitFromObj(metadata) + + @classmethod + def InitFromPackedBuf(cls, buf, pos=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos) + return cls.InitFromBuf(buf, pos+n) + + @classmethod + def InitFromObj(cls, metadata): + x = MetadataT() + x._UnPack(metadata) + return x + + # MetadataT + def _UnPack(self, metadata): + if metadata is None: + return + self.onnxHash = metadata.OnnxHash() + self.description = metadata.Description() + self.license = metadata.License() + self.commit = metadata.Commit() + self.codeRepository = metadata.CodeRepository() + self.modelRepository = metadata.ModelRepository() + self.runId = metadata.RunId() + self.runUrl = metadata.RunUrl() + + # MetadataT + def Pack(self, builder): + if self.onnxHash is not None: + onnxHash = builder.CreateString(self.onnxHash) + if self.description is not None: + description = builder.CreateString(self.description) + if self.license is not None: + license = builder.CreateString(self.license) + if self.commit is not None: + commit = builder.CreateString(self.commit) + if self.codeRepository is not None: + codeRepository = builder.CreateString(self.codeRepository) + if self.modelRepository is not None: + modelRepository = builder.CreateString(self.modelRepository) + if self.runId is not None: + runId = builder.CreateString(self.runId) + if self.runUrl is not None: + runUrl = builder.CreateString(self.runUrl) + MetadataStart(builder) + if self.onnxHash is not None: + MetadataAddOnnxHash(builder, onnxHash) + if self.description is not None: + MetadataAddDescription(builder, description) + if self.license is not None: + MetadataAddLicense(builder, license) + if self.commit is not None: + MetadataAddCommit(builder, commit) + if self.codeRepository is not None: + MetadataAddCodeRepository(builder, codeRepository) + if self.modelRepository is not None: + MetadataAddModelRepository(builder, modelRepository) + if self.runId is not None: + MetadataAddRunId(builder, runId) + if self.runUrl is not None: + MetadataAddRunUrl(builder, runUrl) + metadata = MetadataEnd(builder) + return metadata + + class Model(object): __slots__ = ['_tab'] @@ -4794,8 +4986,18 @@ def Graph(self): return obj return None + # Model + def Metadata(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + x = self._tab.Indirect(o + self._tab.Pos) + obj = Metadata() + obj.Init(self._tab.Bytes, x) + return obj + return None + def ModelStart(builder): - builder.StartObject(2) + builder.StartObject(3) def ModelAddSchemaVersion(builder, schemaVersion): builder.PrependInt32Slot(0, schemaVersion, 0) @@ -4803,6 +5005,9 @@ def ModelAddSchemaVersion(builder, schemaVersion): def ModelAddGraph(builder, graph): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(graph), 0) +def ModelAddMetadata(builder, metadata): + builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(metadata), 0) + def ModelEnd(builder): return builder.EndObject() @@ -4818,6 +5023,7 @@ class ModelT(object): def __init__(self): self.schemaVersion = 0 # type: int self.graph = None # type: Optional[GraphT] + self.metadata = None # type: Optional[MetadataT] @classmethod def InitFromBuf(cls, buf, pos): @@ -4843,15 +5049,21 @@ def _UnPack(self, model): self.schemaVersion = model.SchemaVersion() if model.Graph() is not None: self.graph = GraphT.InitFromObj(model.Graph()) + if model.Metadata() is not None: + 
self.metadata = MetadataT.InitFromObj(model.Metadata()) # ModelT def Pack(self, builder): if self.graph is not None: graph = self.graph.Pack(builder) + if self.metadata is not None: + metadata = self.metadata.Pack(builder) ModelStart(builder) ModelAddSchemaVersion(builder, self.schemaVersion) if self.graph is not None: ModelAddGraph(builder, graph) + if self.metadata is not None: + ModelAddMetadata(builder, metadata) model = ModelEnd(builder) return model
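
Reader's note (not part of the diff): a minimal sketch of how a downstream consumer might read the metadata embedded by this change, using the newly exported `Model::metadata()` and `ModelMetadata` accessors shown above. The model path is only illustrative.

```
use std::error::Error;
use std::fs;

use rten::Model;

fn main() -> Result<(), Box<dyn Error>> {
    // Load a converted model; "model.rten" is a placeholder path.
    let model_bytes = fs::read("model.rten")?;
    let model = Model::load(&model_bytes)?;

    // Every metadata field is optional, so each accessor returns Option<&str>.
    let metadata = model.metadata();
    if let Some(license) = metadata.license() {
        println!("License: {}", license);
    }
    if let Some(hash) = metadata.onnx_hash() {
        println!("ONNX hash: {}", hash);
    }
    if let Some(url) = metadata.run_url() {
        println!("Run URL: {}", url);
    }

    Ok(())
}
```

On the conversion side, the new `-m`/`--metadata` flag of `tools/convert-onnx.py` accepts a JSON file whose keys match the `Metadata` fields (`description`, `license`, `commit`, `code_repository`, `model_repository`, `run_id`, `run_url`); `onnx_hash` is computed automatically from the input ONNX file.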