Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add initial support for model metadata in RTen models #48

Merged
merged 4 commits into from
Feb 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 70 additions & 33 deletions rten-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::error::Error;
use std::fs;
use std::time::Instant;

use rten::{Dimension, Input, Model, NodeId, Output, RunOptions};
use rten::{Dimension, Input, Model, ModelMetadata, NodeId, Output, RunOptions};
use rten_tensor::prelude::*;
use rten_tensor::Tensor;

Expand Down Expand Up @@ -37,6 +37,11 @@ fn parse_args() -> Result<Args, lexopt::Error> {

Usage: {bin_name} [OPTIONS] <model>

Args:
<model>
Path to '.rten' model to inspect and run.

Options:
-t, --timing Output timing info
-v, --verbose Enable verbose logging
-h, --help Print help
Expand Down Expand Up @@ -66,28 +71,27 @@ fn format_param_count(n: usize) -> String {
}
}

/// Tool for inspecting converted ONNX models and running them with randomly
/// generated inputs.
///
/// ```
/// tools/convert-onnx.py model.onnx output.rten
/// cargo run -p rten-cli --release output.rten
/// ```
///
/// To get detailed timing information set the `RTEN_TIMING` env var before
/// running. See `docs/profiling.md`.
fn main() -> Result<(), Box<dyn Error>> {
let args = parse_args()?;
let model_bytes = fs::read(args.model)?;
let model = Model::load(&model_bytes)?;
fn print_metadata(metadata: &ModelMetadata) {
fn print_field<T: std::fmt::Display>(name: &str, value: Option<T>) {
if let Some(value) = value {
println!(" {}: {}", name, value);
}
}

println!(
"Model stats: {} inputs, {} outputs, {} params",
model.input_ids().len(),
model.output_ids().len(),
format_param_count(model.total_params()),
);
println!("Metadata:");
print_field("ONNX hash", metadata.onnx_hash());
print_field("Description", metadata.description());
print_field("License", metadata.license());
print_field("Commit", metadata.commit());
print_field("Repository", metadata.code_repository());
print_field("Model repository", metadata.model_repository());
print_field("Run ID", metadata.run_id());
print_field("Run URL", metadata.run_url());
}

/// Generate random inputs for `model` using shape metadata and heuristics,
/// run it, and print details of the output.
fn run_with_random_input(model: &Model, run_opts: RunOptions) -> Result<(), Box<dyn Error>> {
let mut rng = fastrand::Rng::new();

// Generate random ints that are likely to be valid token IDs in a language
Expand Down Expand Up @@ -165,27 +169,21 @@ fn main() -> Result<(), Box<dyn Error>> {
.as_ref()
.and_then(|ni| ni.name())
.unwrap_or("(unnamed)");
println!("Input \"{name}\" resolved shape {:?}", input.shape());
println!(" Input \"{name}\" generated shape {:?}", input.shape());
}

// Run model and summarize outputs.
let start = Instant::now();
let outputs = model.run(
&inputs,
model.output_ids(),
Some(RunOptions {
timing: args.timing,
verbose: args.verbose,
..Default::default()
}),
)?;
let outputs = model.run(&inputs, model.output_ids(), Some(run_opts))?;
let elapsed = start.elapsed().as_millis();

println!();
println!(
"Model returned {} outputs in {:.2}ms",
" Model returned {} outputs in {:.2}ms.",
outputs.len(),
elapsed
);
println!();

let output_names: Vec<String> = model
.output_ids()
Expand All @@ -204,11 +202,50 @@ fn main() -> Result<(), Box<dyn Error>> {
Output::IntTensor(_) => "i32",
};
println!(
"Output {i} \"{name}\" data type {} shape: {:?}",
" Output {i} \"{name}\" data type {} shape: {:?}",
dtype,
output.shape()
);
}

Ok(())
}

/// Tool for inspecting converted ONNX models and running them with randomly
/// generated inputs.
///
/// ```
/// tools/convert-onnx.py model.onnx output.rten
/// cargo run -p rten-cli --release output.rten
/// ```
///
/// To get detailed timing information set the `RTEN_TIMING` env var before
/// running. See `docs/profiling.md`.
fn main() -> Result<(), Box<dyn Error>> {
let args = parse_args()?;
let model_bytes = fs::read(args.model)?;
let model = Model::load(&model_bytes)?;

println!(
"Model summary: {} inputs, {} outputs, {} params",
model.input_ids().len(),
model.output_ids().len(),
format_param_count(model.total_params()),
);
println!();

print_metadata(model.metadata());

println!();
println!("Running model with random inputs...");
run_with_random_input(
&model,
RunOptions {
timing: args.timing,
verbose: args.verbose,
..Default::default()
},
)?;

Ok(())
}
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ mod gemm;
mod graph;
mod iter_util;
mod model;
mod model_metadata;
mod number;
mod slice_reductions;
mod timer;
Expand All @@ -44,6 +45,7 @@ pub mod ops;

pub use graph::{Dimension, NodeId, RunOptions};
pub use model::{DefaultOperatorFactory, Model, ModelLoadError, NodeInfo, OpRegistry};
pub use model_metadata::ModelMetadata;
pub use ops::{FloatOperators, Input, Operators, Output};
pub use timer::Timer;
pub use timing::TimingSort;
Expand Down
27 changes: 26 additions & 1 deletion src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use rten_tensor::Tensor;
use smallvec::smallvec;

use crate::graph::{Dimension, Graph, Node, NodeId, RunError, RunOptions};
use crate::model_metadata::ModelMetadata;
use crate::ops;
use crate::ops::{
BoxOrder, CoordTransformMode, DataType, Direction, Input, NearestMode, Operator, Output,
Expand Down Expand Up @@ -52,6 +53,7 @@ pub struct Model {
input_ids: Vec<NodeId>,
output_ids: Vec<NodeId>,
graph: Graph,
metadata: ModelMetadata,
}

/// Provides access to metadata about a graph node.
Expand Down Expand Up @@ -144,6 +146,11 @@ impl Model {
self.graph.get_node(id).map(|node| NodeInfo { node })
}

/// Return metadata about the model.
pub fn metadata(&self) -> &ModelMetadata {
&self.metadata
}

/// Return the IDs of input nodes.
pub fn input_ids(&self) -> &[NodeId] {
&self.input_ids
Expand Down Expand Up @@ -1108,11 +1115,17 @@ fn load_model(data: &[u8], registry: &OpRegistry) -> Result<Model, ModelLoadErro
}
}

let metadata = model
.metadata()
.map(ModelMetadata::deserialize)
.unwrap_or_default();

let model = Model {
node_ids: node_id_from_name,
input_ids,
output_ids,
graph,
metadata,
};
Ok(model)
}
Expand All @@ -1126,7 +1139,7 @@ mod tests {

use crate::graph::{Dimension, RunError};
use crate::model::Model;
use crate::model_builder::{ModelBuilder, OpType};
use crate::model_builder::{MetadataArgs, ModelBuilder, OpType};
use crate::ops;
use crate::ops::{BoxOrder, CoordTransformMode, NearestMode, OpError, ResizeMode, Scalar};

Expand Down Expand Up @@ -1157,6 +1170,10 @@ mod tests {
);
builder.add_operator("relu", OpType::Relu, &[Some(concat_out)], &[output_node]);

builder.add_metadata(MetadataArgs {
onnx_hash: Some("abc".to_string()),
});

builder.finish()
}

Expand Down Expand Up @@ -1198,6 +1215,14 @@ mod tests {
assert_eq!(shape, &[1, 2, 2].map(Dimension::Fixed));
}

#[test]
fn test_metadata() {
let buffer = generate_model_buffer();
let model = Model::load(&buffer).unwrap();
assert_eq!(model.metadata().onnx_hash(), Some("abc"));
assert_eq!(model.metadata().description(), None);
}

#[test]
fn test_input_shape() {
let buffer = generate_model_buffer();
Expand Down
21 changes: 21 additions & 0 deletions src/model_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ pub struct ModelBuilder<'a> {
nodes: Vec<WIPOffset<sg::Node<'a>>>,
input_ids: Vec<u32>,
output_ids: Vec<u32>,
metadata: Option<WIPOffset<sg::Metadata<'a>>>,
}

enum NodeData<'a> {
Expand All @@ -126,6 +127,11 @@ enum NodeData<'a> {
Operator(WIPOffset<sg::OperatorNode<'a>>),
}

/// Arguments for [ModelBuilder::add_metadata].
pub struct MetadataArgs {
pub onnx_hash: Option<String>,
}

struct PadArgs {
pad_mode: sg::PadMode,
pads: Option<Vec<usize>>,
Expand All @@ -152,6 +158,7 @@ impl<'a> ModelBuilder<'a> {
nodes: Vec::new(),
input_ids: Vec::new(),
output_ids: Vec::new(),
metadata: None,
}
}

Expand Down Expand Up @@ -683,6 +690,19 @@ impl<'a> ModelBuilder<'a> {
self.output_ids.push(node_id);
}

/// Add model metadata
pub fn add_metadata(&mut self, metadata: MetadataArgs) {
let hash = metadata
.onnx_hash
.as_ref()
.map(|hash| self.builder.create_string(hash));
let mut meta_builder = sg::MetadataBuilder::new(&mut self.builder);
if let Some(hash) = hash {
meta_builder.add_onnx_hash(hash);
}
self.metadata = Some(meta_builder.finish());
}

/// Finish writing the model data to the buffer and return the buffer's contents.
pub fn finish(mut self) -> Vec<u8> {
let inputs_vec = self.builder.create_vector(&self.input_ids[..]);
Expand All @@ -703,6 +723,7 @@ impl<'a> ModelBuilder<'a> {
&sg::ModelArgs {
schema_version: 1,
graph: Some(graph),
metadata: self.metadata,
},
);

Expand Down
Loading
Loading