Skip to content

Commit

Permalink
Support specifying custom input shapes in rten cli
Browse files Browse the repository at this point in the history
This enables experimenting with how performance changes when varying the input
shape, or testing special cases of a shape.
  • Loading branch information
robertknight committed Feb 5, 2024
1 parent 4515849 commit 485f221
Showing 1 changed file with 65 additions and 16 deletions.
81 changes: 65 additions & 16 deletions rten-cli/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::VecDeque;
use std::collections::{HashMap, VecDeque};
use std::error::Error;
use std::fs;
use std::time::Instant;
Expand All @@ -16,6 +16,30 @@ struct Args {

/// Enable verbose logging for model execution.
verbose: bool,

/// Map of `(input_name, dims)` with custom shapes for inputs.
input_shapes: HashMap<String, Vec<usize>>,
}

/// Parse an input shape specifier in the form `input_name=dim0,dim1,...`.
///
/// Returns a tuple of (name, shape).
fn parse_shape_spec(spec: &str) -> Result<(String, Vec<usize>), lexopt::Error> {
let parts: Vec<&str> = spec.split('=').collect();
if parts.len() != 2 {
return Err(lexopt::Error::Custom(
"Invalid input format. Expected input_name=dim0,dim1,...".into(),
));
}

let name = parts[0].to_string();
let dims_str = parts[1];
let parsed_dims: Result<Vec<usize>, _> = dims_str.split(',').map(|dim| dim.parse()).collect();

match parsed_dims {
Ok(dims) => Ok((name, dims)),
Err(e) => Err(lexopt::Error::Custom(e.into())),
}
}

fn parse_args() -> Result<Args, lexopt::Error> {
Expand All @@ -24,13 +48,19 @@ fn parse_args() -> Result<Args, lexopt::Error> {
let mut values = VecDeque::new();
let mut timing = false;
let mut verbose = false;
let mut input_shapes = HashMap::new();

let mut parser = lexopt::Parser::from_env();
while let Some(arg) = parser.next()? {
match arg {
Value(val) => values.push_back(val.string()?),
Short('v') | Long("verbose") => verbose = true,
Short('t') | Long("timing") => timing = true,
Short('s') | Long("shape") => {
let value = parser.value()?.string()?;
let (name, shape) = parse_shape_spec(&value)?;
input_shapes.insert(name, shape);
}
Short('h') | Long("help") => {
println!(
"Inspect and run RTen models.
Expand All @@ -45,6 +75,9 @@ Options:
-t, --timing Output timing info
-v, --verbose Enable verbose logging
-h, --help Print help
-s, --shape <shape>
Specify shape for an input in the form `name=dim0,dim1,...`
",
bin_name = parser.bin_name().unwrap_or("rten")
);
Expand All @@ -60,6 +93,7 @@ Options:
model,
timing,
verbose,
input_shapes,
})
}

Expand Down Expand Up @@ -91,7 +125,15 @@ fn print_metadata(metadata: &ModelMetadata) {

/// Generate random inputs for `model` using shape metadata and heuristics,
/// run it, and print details of the output.
fn run_with_random_input(model: &Model, run_opts: RunOptions) -> Result<(), Box<dyn Error>> {
///
/// `custom_shapes` is a map of (input_name, dims) to use as shapes for inputs.
/// If a shape is not specified for an input, one is generated using heuristics
/// and the shape information specified by the model.
fn run_with_random_input(
model: &Model,
custom_shapes: &HashMap<String, Vec<usize>>,
run_opts: RunOptions,
) -> Result<(), Box<dyn Error>> {
let mut rng = fastrand::Rng::new();

// Generate random ints that are likely to be valid token IDs in a language
Expand All @@ -108,20 +150,26 @@ fn run_with_random_input(model: &Model, run_opts: RunOptions) -> Result<(), Box<
let shape = info
.shape()
.ok_or(format!("Unable to get shape for input {}", name))?;
let mut resolved_shape: Vec<usize> = Vec::new();
for dim in shape {
let size = match dim {
// Guess a suitable size for an input dimension based on
// the name.
Dimension::Symbolic(name) => match name.as_str() {
"batch" | "batch_size" => 1,
"sequence" | "sequence_length" => 128,
_ => 256,
},
Dimension::Fixed(size) => size,
};
resolved_shape.push(size)
}

let resolved_shape = if let Some(shape) = custom_shapes.get(name) {
shape.clone()
} else {
shape
.iter()
.map(|dim| {
match dim {
// Guess a suitable size for an input dimension based on
// the name.
Dimension::Symbolic(name) => match name.as_str() {
"batch" | "batch_size" => 1,
"sequence" | "sequence_length" => 128,
_ => 256,
},
Dimension::Fixed(size) => *size,
}
})
.collect()
};

// Guess suitable content for the input based on its name.
let tensor = match name {
Expand Down Expand Up @@ -278,6 +326,7 @@ fn main() -> Result<(), Box<dyn Error>> {
println!("Running model with random inputs...");
run_with_random_input(
&model,
&args.input_shapes,
RunOptions {
timing: args.timing,
verbose: args.verbose,
Expand Down

0 comments on commit 485f221

Please sign in to comment.