Skip to content

Commit

Permalink
Add TrOCR example
Browse files Browse the repository at this point in the history
  • Loading branch information
robertknight committed Aug 14, 2024
1 parent 1df58a9 commit 73e0888
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 0 deletions.
4 changes: 4 additions & 0 deletions rten-examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ path = "src/depth_anything.rs"
name = "segment_anything"
path = "src/segment_anything.rs"

[[bin]]
name = "trocr"
path = "src/trocr.rs"

# Text
[[bin]]
name = "bert_qa"
Expand Down
1 change: 1 addition & 0 deletions rten-examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ The examples have been chosen to cover common tasks and popular models.
- **detr** - Object detection using [DETR](https://research.facebook.com/publications/end-to-end-object-detection-with-transformers/)
- **distilvit** - Image captioning using [Mozilla's DistilViT](https://hacks.mozilla.org/2024/05/experimenting-with-local-alt-text-generation-in-firefox-nightly/)
- **segment_anything** - Image segmentation using [Segment Anything](https://segment-anything.com)
- **trocr** - Recognize text using [TrOCR](https://arxiv.org/abs/2109.10282)
- **yolo** - Object detection using [YOLO v8](https://github.com/ultralytics/ultralytics)

### Text
Expand Down
145 changes: 145 additions & 0 deletions rten-examples/src/trocr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
use std::collections::VecDeque;
use std::error::Error;
use std::fs;
use std::io::prelude::*;

use rten::{FloatOperators, Model};
use rten_generate::{Generator, GeneratorUtils};
use rten_imageio::read_image;
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, NdTensorViewMut};
use rten_text::tokenizers::Tokenizer;

/// Command-line arguments collected by `parse_args`.
struct Args {
    // Path to the image encoder model (`.rten` file).
    encoder_model: String,
    // Path to the text decoder model (`.rten` file).
    decoder_model: String,
    // Path to the `tokenizer.json` configuration file.
    tokenizer_config: String,
    // Path to the image containing text to recognize.
    image_path: String,
}

/// Parse command-line arguments into an [`Args`] struct.
///
/// Positional arguments are consumed in order: encoder model, decoder model,
/// tokenizer config, image path. `--help` prints usage and exits.
///
/// Returns an error if an unexpected flag is given or a positional argument
/// is missing.
fn parse_args() -> Result<Args, lexopt::Error> {
    use lexopt::prelude::*;

    let mut values = VecDeque::new();
    let mut parser = lexopt::Parser::from_env();

    while let Some(arg) = parser.next()? {
        match arg {
            Value(val) => values.push_back(val.string()?),
            Long("help") => {
                println!(
                    // Fixed: help text and binary-name fallback were copied
                    // from the `distilvit` captioning example.
                    "Recognize text in an image.
Usage: {bin_name} [options] <encoder_model> <decoder_model> <tokenizer> <image>
Args:
<encoder_model> - Image encoder model
<decoder_model> - Text decoder model
<tokenizer> - `tokenizer.json` file
<image> - Image path
",
                    bin_name = parser.bin_name().unwrap_or("trocr")
                );
                std::process::exit(0);
            }
            _ => return Err(arg.unexpected()),
        }
    }

    // Positional arguments, in the order documented above.
    let encoder_model = values.pop_front().ok_or("missing `encoder_model` arg")?;
    let decoder_model = values.pop_front().ok_or("missing `decoder_model` arg")?;
    let tokenizer_config = values.pop_front().ok_or("missing `tokenizer` arg")?;
    let image_path = values.pop_front().ok_or("missing `image_path` arg")?;

    let args = Args {
        encoder_model,
        decoder_model,
        tokenizer_config,
        image_path,
    };

    Ok(args)
}

/// Normalize a single pixel value for the given color channel.
///
/// Subtracts the per-channel mean and divides by the per-channel standard
/// deviation; the statistics come from the model's `preprocessor_config.json`.
///
/// Panics if `channel` is not 0, 1 or 2.
fn normalize_pixel(value: f32, channel: usize) -> f32 {
    // Per-channel statistics taken from `preprocessor_config.json`.
    const MEAN: [f32; 3] = [0.5, 0.5, 0.5];
    const STD_DEV: [f32; 3] = [0.5, 0.5, 0.5];

    assert!(channel < 3, "channel index is invalid");

    (value - MEAN[channel]) / STD_DEV[channel]
}

/// Normalize all pixels of an image in place using per-channel statistics.
///
/// Iterates over axis 0 of `img` and applies [`normalize_pixel`] to every
/// element of each plane. Axis 0 is treated as the channel dimension (the
/// channel index is forwarded to `normalize_pixel`, which requires it to be
/// < 3) — assumes `img` is in CHW layout; TODO confirm against `read_image`.
fn normalize_image(mut img: NdTensorViewMut<f32, 3>) {
    for chan in 0..img.size(0) {
        img.slice_mut::<2, _>(chan)
            .apply(|x| normalize_pixel(*x, chan));
    }
}

/// Recognize text line images using TrOCR [^1].
///
/// First use Hugging Face's Optimum tool to download and export the models to
/// ONNX:
///
/// ```
/// optimum-cli export onnx --model microsoft/trocr-base-printed trocr-base-printed
/// ```
///
/// Convert the models to `.rten` format:
///
/// ```
/// rten-convert trocr-base-printed/encoder_model.onnx
/// rten-convert trocr-base-printed/decoder_model.onnx
/// ```
///
/// Run the model, specifying the image to recognize:
///
/// ```sh
/// cargo run --release --bin trocr trocr-base-printed/encoder_model.rten trocr-base-printed/decoder_model.rten tokenizer.json <image>
/// ```
///
/// [^1]: https://arxiv.org/abs/2109.10282
fn main() -> Result<(), Box<dyn Error>> {
    let args = parse_args()?;
    // SAFETY(review): `Model::load_mmap` is unsafe, presumably because the
    // mapped file must not be modified while the model is in use — confirm
    // against the rten `Model::load_mmap` documentation.
    let encoder_model = unsafe { Model::load_mmap(args.encoder_model)? };
    let decoder_model = unsafe { Model::load_mmap(args.decoder_model)? };
    let tokenizer_config = fs::read_to_string(&args.tokenizer_config)?;
    let tokenizer = Tokenizer::from_json(&tokenizer_config)?;
    // Prepare the input: add a leading batch dimension, resize to the fixed
    // 384x384 input size, then normalize pixels in place.
    let mut image = read_image(args.image_path)?.into_dyn();
    image.insert_axis(0); // Add batch dim
    let mut image = image.resize_image([384, 384])?;
    normalize_image(image.slice_mut(0));

    // Run the encoder once; its output is fed to the decoder on every
    // generation step via a constant input below.
    let encoded_image: NdTensor<f32, 3> = encoder_model
        .run_one(image.view().into(), None)?
        .try_into()?;

    let encoder_hidden_states_id = decoder_model.node_id("encoder_hidden_states")?;

    // `decoder_start_token_id` from `generation_config.json`. This is the `</s>`
    // token.
    let decoder_start_token = 2;
    let eos_token = 2;

    // Upper bound on generated tokens, in case the model never emits EOS.
    let max_tokens = 100;

    // Autoregressive decoding: start from the prompt token, stop at EOS or
    // `max_tokens`, and decode token IDs to text with the tokenizer.
    let prompt = vec![decoder_start_token];
    let generator = Generator::from_model(&decoder_model)?
        .with_prompt(&prompt)
        .with_constant_input(encoder_hidden_states_id, encoded_image.view().into())
        .stop_on_tokens([eos_token])
        .take(max_tokens)
        .decode(&tokenizer);

    // Stream each decoded token to stdout as soon as it is produced.
    for token in generator {
        let token = token?;

        print!("{}", token);
        // Flush so partial output appears immediately; ignore flush errors.
        let _ = std::io::stdout().flush();
    }

    Ok(())
}

0 comments on commit 73e0888

Please sign in to comment.