Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add --n_iters flag to CLI and perform constant propagation before running model #202

Merged
merged 3 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 55 additions & 10 deletions rten-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ struct Args {

/// Sizes for dynamic dimensions of inputs.
input_sizes: Vec<DimSize>,

/// Number of times to run model.
n_iters: u32,
}

/// Specifies the size for a dynamic input dimension.
Expand Down Expand Up @@ -92,6 +95,8 @@ fn parse_args() -> Result<Args, lexopt::Error> {
use lexopt::prelude::*;

let mut values = VecDeque::new();

let mut n_iters = 1;
let mut timing = false;
let mut verbose = false;
let mut input_sizes = Vec::new();
Expand All @@ -100,6 +105,12 @@ fn parse_args() -> Result<Args, lexopt::Error> {
while let Some(arg) = parser.next()? {
match arg {
Value(val) => values.push_back(val.string()?),
Short('n') | Long("n_iters") => {
let value = parser.value()?.string()?;
n_iters = value
.parse()
.map_err(|_| format!("Unable to parse `n_iters`"))?;
}
Short('v') | Long("verbose") => verbose = true,
Short('V') | Long("version") => {
println!("rten {}", env!("CARGO_PKG_VERSION"));
Expand All @@ -124,6 +135,9 @@ Args:

Options:
-h, --help Print help
-n, --n_iters <n>
Number of times to evaluate model

-t, --timing Output timing info

-s, --size <spec>
Expand All @@ -145,6 +159,7 @@ Options:

Ok(Args {
model,
n_iters,
timing,
verbose,
input_sizes,
Expand Down Expand Up @@ -185,6 +200,7 @@ fn run_with_random_input(
model: &Model,
dim_sizes: &[DimSize],
run_opts: RunOptions,
n_iters: u32,
) -> Result<(), Box<dyn Error>> {
let mut rng = fastrand::Rng::new();

Expand Down Expand Up @@ -257,7 +273,7 @@ fn run_with_random_input(
)?;

// Convert inputs from `Output` (owned) to `Input` (view).
- let inputs: Vec<(NodeId, Input)> = inputs
+ let mut inputs: Vec<(NodeId, Input)> = inputs
.iter()
.map(|(id, output)| (*id, Input::from(output)))
.collect();
Expand All @@ -271,17 +287,45 @@ fn run_with_random_input(
println!(" Input \"{name}\" generated shape {:?}", input.shape());
}

// Run model and summarize outputs.
let start = Instant::now();
let outputs = model.run(&inputs, model.output_ids(), Some(run_opts))?;
let elapsed = start.elapsed().as_millis();
// Evaluate operators that don't depend on any inputs.
//
// ONNX Runtime does this when graph optimizations are enabled. RTen
// doesn't have any built-in graph optimizations yet, so we have to do this
// manually.
let opt_start = Instant::now();
let const_prop = model.partial_run(&[], model.output_ids(), None)?;
for (node_id, const_val) in const_prop.iter() {
inputs.push((*node_id, const_val.into()));
}
let opt_elapsed = opt_start.elapsed().as_millis();
if const_prop.len() > 0 {
println!(
" Constant propagation produced {} values in {:.2}ms",
const_prop.len(),
opt_elapsed
);
}

// Run model and summarize outputs.
println!();
println!(
" Model returned {} outputs in {:.2}ms.",
outputs.len(),
elapsed
);
let mut remaining_iters = n_iters.max(1);
let mut outputs;
loop {
let start = Instant::now();
outputs = model.run(&inputs, model.output_ids(), Some(run_opts.clone()))?;
let elapsed = start.elapsed().as_millis();

println!(
" Model returned {} outputs in {:.2}ms.",
outputs.len(),
elapsed
);

remaining_iters -= 1;
if remaining_iters == 0 {
break;
}
}
println!();

let output_names: Vec<String> = model
Expand Down Expand Up @@ -382,6 +426,7 @@ fn main() -> Result<(), Box<dyn Error>> {
verbose: args.verbose,
..Default::default()
},
args.n_iters,
)?;

Ok(())
Expand Down
2 changes: 1 addition & 1 deletion src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ impl Error for RunError {}

/// Options that control logging and other behaviors when executing a
/// [Model](crate::Model).
- #[derive(Default)]
+ #[derive(Clone, Default, PartialEq)]
pub struct RunOptions {
/// Whether to log times spent in different operators when run completes.
pub timing: bool,
Expand Down
2 changes: 1 addition & 1 deletion src/timing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ pub struct TimingRecord<'a> {
}

/// Specifies sort order for graph run timings.
- #[derive(Default)]
+ #[derive(Clone, Default, PartialEq)]
pub enum TimingSort {
/// Sort timings by operator name
ByName,
Expand Down
Loading