Skip to content

Commit

Permalink
Bump rten to v0.13.0
Browse files Browse the repository at this point in the history
  • Loading branch information
robertknight committed Aug 24, 2024
1 parent 6f37b2b commit 448c6bc
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 25 deletions.
12 changes: 6 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ members = [
]

[workspace.dependencies]
rten = { version = "0.12.0" }
rten-imageproc = { version = "0.12.0" }
rten-tensor = { version = "0.12.0" }
rten = { version = "0.13.0" }
rten-imageproc = { version = "0.13.0" }
rten-tensor = { version = "0.13.0" }
42 changes: 26 additions & 16 deletions ocrs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,9 @@ mod tests {
/// bias to produce an output "probability map".
fn fake_detection_model() -> Model {
let mut mb = ModelBuilder::new(ModelFormat::V1);
let input_id = mb.add_value(
let mut gb = mb.graph_builder();

let input_id = gb.add_value(
"input",
Some(&[
Dimension::Symbolic("batch".to_string()),
Expand All @@ -252,20 +254,23 @@ mod tests {
Dimension::Fixed(100),
]),
);
mb.add_input(input_id);
gb.add_input(input_id);

let output_id = mb.add_value("output", None);
mb.add_output(output_id);
let output_id = gb.add_value("output", None);
gb.add_output(output_id);

let bias = Tensor::from_scalar(0.5);
let bias_id = mb.add_constant(bias.view());
mb.add_operator(
let bias_id = gb.add_constant(bias.view());
gb.add_operator(
"add",
OpType::Add,
&[Some(input_id), Some(bias_id)],
&[output_id],
);

let graph = gb.finish();
mb.set_graph(graph);

let model_data = mb.finish();
Model::load(model_data).unwrap()
}
Expand All @@ -278,7 +283,9 @@ mod tests {
/// each column of the input as a one-hot vector of probabilities.
fn fake_recognition_model() -> Model {
let mut mb = ModelBuilder::new(ModelFormat::V1);
let input_id = mb.add_value(
let mut gb = mb.graph_builder();

let input_id = gb.add_value(
"input",
Some(&[
Dimension::Symbolic("batch".to_string()),
Expand All @@ -287,11 +294,11 @@ mod tests {
Dimension::Symbolic("seq".to_string()),
]),
);
mb.add_input(input_id);
gb.add_input(input_id);

// MaxPool to scale width by 1/4: NCHW => NCHW/4
let pool_out = mb.add_value("max_pool_out", None);
mb.add_operator(
let pool_out = gb.add_value("max_pool_out", None);
gb.add_operator(
"max_pool",
OpType::MaxPool(MaxPool {
kernel_size: [1, 4],
Expand All @@ -304,18 +311,18 @@ mod tests {

// Squeeze to remove the channel dim: NCHW/4 => NHW/4
let squeeze_axes = Tensor::from_vec(vec![1]);
let squeeze_axes_id = mb.add_constant(squeeze_axes.view());
let squeeze_out = mb.add_value("squeeze_out", None);
mb.add_operator(
let squeeze_axes_id = gb.add_constant(squeeze_axes.view());
let squeeze_out = gb.add_value("squeeze_out", None);
gb.add_operator(
"squeeze",
OpType::Squeeze,
&[Some(pool_out), Some(squeeze_axes_id)],
&[squeeze_out],
);

// Transpose: NHW/4 => W/4NH
let transpose_out = mb.add_value("transpose_out", None);
mb.add_operator(
let transpose_out = gb.add_value("transpose_out", None);
gb.add_operator(
"transpose",
OpType::Transpose(Transpose {
perm: Some(vec![2, 0, 1]),
Expand All @@ -324,7 +331,10 @@ mod tests {
&[transpose_out],
);

mb.add_output(transpose_out);
gb.add_output(transpose_out);
let graph = gb.finish();

mb.set_graph(graph);

let model_data = mb.finish();
Model::load(model_data).unwrap()
Expand Down

0 comments on commit 448c6bc

Please sign in to comment.