Skip to content

Commit

Permalink
Merge pull request #202 from robertknight/rten-cli-multiple-iters
Browse files Browse the repository at this point in the history
Add `--n_iters` flag to CLI and perform constant propagation before running model
  • Loading branch information
robertknight authored May 21, 2024
2 parents c7dbdc2 + 80cbf77 commit 2c24dbb
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 12 deletions.
65 changes: 55 additions & 10 deletions rten-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ struct Args {

/// Sizes for dynamic dimensions of inputs.
input_sizes: Vec<DimSize>,

/// Number of times to run model.
n_iters: u32,
}

/// Specifies the size for a dynamic input dimension.
Expand Down Expand Up @@ -92,6 +95,8 @@ fn parse_args() -> Result<Args, lexopt::Error> {
use lexopt::prelude::*;

let mut values = VecDeque::new();

let mut n_iters = 1;
let mut timing = false;
let mut verbose = false;
let mut input_sizes = Vec::new();
Expand All @@ -100,6 +105,12 @@ fn parse_args() -> Result<Args, lexopt::Error> {
while let Some(arg) = parser.next()? {
match arg {
Value(val) => values.push_back(val.string()?),
Short('n') | Long("n_iters") => {
let value = parser.value()?.string()?;
n_iters = value
.parse()
.map_err(|_| format!("Unable to parse `n_iters`"))?;
}
Short('v') | Long("verbose") => verbose = true,
Short('V') | Long("version") => {
println!("rten {}", env!("CARGO_PKG_VERSION"));
Expand All @@ -124,6 +135,9 @@ Args:
Options:
-h, --help Print help
-n, --n_iters <n>
Number of times to evaluate model
-t, --timing Output timing info
-s, --size <spec>
Expand All @@ -145,6 +159,7 @@ Options:

Ok(Args {
model,
n_iters,
timing,
verbose,
input_sizes,
Expand Down Expand Up @@ -185,6 +200,7 @@ fn run_with_random_input(
model: &Model,
dim_sizes: &[DimSize],
run_opts: RunOptions,
n_iters: u32,
) -> Result<(), Box<dyn Error>> {
let mut rng = fastrand::Rng::new();

Expand Down Expand Up @@ -257,7 +273,7 @@ fn run_with_random_input(
)?;

// Convert inputs from `Output` (owned) to `Input` (view).
let inputs: Vec<(NodeId, Input)> = inputs
let mut inputs: Vec<(NodeId, Input)> = inputs
.iter()
.map(|(id, output)| (*id, Input::from(output)))
.collect();
Expand All @@ -271,17 +287,45 @@ fn run_with_random_input(
println!(" Input \"{name}\" generated shape {:?}", input.shape());
}

// Run model and summarize outputs.
let start = Instant::now();
let outputs = model.run(&inputs, model.output_ids(), Some(run_opts))?;
let elapsed = start.elapsed().as_millis();
// Evaluate operators that don't depend on any inputs.
//
// ONNX Runtime does this when graph optimizations are enabled. RTen
// doesn't have any built-in graph optimizations yet, so we have to do this
// manually.
let opt_start = Instant::now();
let const_prop = model.partial_run(&[], model.output_ids(), None)?;
for (node_id, const_val) in const_prop.iter() {
inputs.push((*node_id, const_val.into()));
}
let opt_elapsed = opt_start.elapsed().as_millis();
if const_prop.len() > 0 {
println!(
" Constant propagation produced {} values in {:.2}ms",
const_prop.len(),
opt_elapsed
);
}

// Run model and summarize outputs.
println!();
println!(
" Model returned {} outputs in {:.2}ms.",
outputs.len(),
elapsed
);
let mut remaining_iters = n_iters.max(1);
let mut outputs;
loop {
let start = Instant::now();
outputs = model.run(&inputs, model.output_ids(), Some(run_opts.clone()))?;
let elapsed = start.elapsed().as_millis();

println!(
" Model returned {} outputs in {:.2}ms.",
outputs.len(),
elapsed
);

remaining_iters -= 1;
if remaining_iters == 0 {
break;
}
}
println!();

let output_names: Vec<String> = model
Expand Down Expand Up @@ -382,6 +426,7 @@ fn main() -> Result<(), Box<dyn Error>> {
verbose: args.verbose,
..Default::default()
},
args.n_iters,
)?;

Ok(())
Expand Down
2 changes: 1 addition & 1 deletion src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ impl Error for RunError {}

/// Options that control logging and other behaviors when executing a
/// [Model](crate::Model).
#[derive(Default)]
#[derive(Clone, Default, PartialEq)]
pub struct RunOptions {
/// Whether to log times spent in different operators when run completes.
pub timing: bool,
Expand Down
2 changes: 1 addition & 1 deletion src/timing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ pub struct TimingRecord<'a> {
}

/// Specifies sort order for graph run timings.
#[derive(Default)]
#[derive(Clone, Default, PartialEq)]
pub enum TimingSort {
/// Sort timings by operator name
ByName,
Expand Down

0 comments on commit 2c24dbb

Please sign in to comment.