Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TrOCR example #304

Merged
merged 3 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions rten-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,12 @@ fn run_with_random_input(
// of these.
name if name.ends_with("_ids") => Output::from(Tensor::<i32>::zeros(&resolved_shape)),

// Optimum can export "merged" transformer models which have two
// branches. One accepts KV-cache inputs and the other does not.
// Set this to false as a "safer" value because we don't have
// cached outputs from a previous run.
"use_cache_branch" => Output::from(Tensor::from(0i32)),

// For anything else, random floats in [0, 1].
//
// TODO - Value nodes in the model should include data types,
Expand Down
4 changes: 4 additions & 0 deletions rten-examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ path = "src/depth_anything.rs"
name = "segment_anything"
path = "src/segment_anything.rs"

[[bin]]
name = "trocr"
path = "src/trocr.rs"

# Text
[[bin]]
name = "bert_qa"
Expand Down
1 change: 1 addition & 0 deletions rten-examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ The examples have been chosen to cover common tasks and popular models.
- **detr** - Object detection using [DETR](https://research.facebook.com/publications/end-to-end-object-detection-with-transformers/)
- **distilvit** - Image captioning using [Mozilla's DistilViT](https://hacks.mozilla.org/2024/05/experimenting-with-local-alt-text-generation-in-firefox-nightly/)
- **segment_anything** - Image segmentation using [Segment Anything](https://segment-anything.com)
- **trocr** - Recognize text using [TrOCR](https://arxiv.org/abs/2109.10282)
- **yolo** - Object detection using [YOLO v8](https://github.com/ultralytics/ultralytics)

### Text
Expand Down
148 changes: 148 additions & 0 deletions rten-examples/src/trocr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
use std::collections::VecDeque;
use std::error::Error;
use std::fs;
use std::io::prelude::*;

use rten::{FloatOperators, Model};
use rten_generate::{Generator, GeneratorUtils};
use rten_imageio::read_image;
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, NdTensorViewMut};
use rten_text::tokenizers::Tokenizer;

/// Command-line arguments for the TrOCR example, as parsed by `parse_args`.
struct Args {
    /// Path to the image encoder model (`.rten` format).
    encoder_model: String,
    /// Path to the text decoder model (`.rten` format). This should be the
    /// "merged" decoder export, which supports both cached and uncached runs.
    decoder_model: String,
    /// Path to the `tokenizer.json` tokenizer configuration file.
    tokenizer_config: String,
    /// Path to the image containing the text line to recognize.
    image_path: String,
}

/// Parse command-line arguments into an [`Args`] struct.
///
/// Expects four positional arguments (encoder model, decoder model, tokenizer
/// config, image path) and supports `--help`. Returns an error for any
/// missing or unexpected argument.
fn parse_args() -> Result<Args, lexopt::Error> {
    use lexopt::prelude::*;

    let mut values = VecDeque::new();
    let mut parser = lexopt::Parser::from_env();

    while let Some(arg) = parser.next()? {
        match arg {
            Value(val) => values.push_back(val.string()?),
            Long("help") => {
                println!(
                    "Read text from an image containing a single text line.

Usage: {bin_name} [options] <encoder_model> <decoder_model> <tokenizer> <image>

Args:

  <encoder_model> - Image encoder model
  <decoder_model> - Text decoder model
  <tokenizer> - `tokenizer.json` file
  <image> - Image path
",
                    // Fix: fall back to this example's binary name. The
                    // previous fallback, "distilvit", was copied from the
                    // distilvit example by mistake.
                    bin_name = parser.bin_name().unwrap_or("trocr")
                );
                std::process::exit(0);
            }
            _ => return Err(arg.unexpected()),
        }
    }

    let encoder_model = values.pop_front().ok_or("missing `encoder_model` arg")?;
    let decoder_model = values.pop_front().ok_or("missing `decoder_model` arg")?;
    let tokenizer_config = values.pop_front().ok_or("missing `tokenizer` arg")?;
    let image_path = values.pop_front().ok_or("missing `image_path` arg")?;

    let args = Args {
        encoder_model,
        decoder_model,
        tokenizer_config,
        image_path,
    };

    Ok(args)
}

/// Normalize a single pixel `value` from the given color `channel` by
/// subtracting the per-channel mean and dividing by the per-channel standard
/// deviation.
///
/// Panics if `channel` is not 0, 1 or 2.
fn normalize_pixel(value: f32, channel: usize) -> f32 {
    assert!(channel < 3, "channel index is invalid");

    // Per-channel normalization constants taken from
    // `preprocessor_config.json`.
    const MEAN: [f32; 3] = [0.5, 0.5, 0.5];
    const STD_DEV: [f32; 3] = [0.5, 0.5, 0.5];

    let centered = value - MEAN[channel];
    centered / STD_DEV[channel]
}

/// Normalize an image tensor in place, applying per-channel mean/std-dev
/// normalization via [`normalize_pixel`].
///
/// `img` is assumed to be in channels-first (CHW) layout — the first axis is
/// iterated as the channel axis. TODO confirm against caller.
fn normalize_image(mut img: NdTensorViewMut<f32, 3>) {
    for chan in 0..img.size(0) {
        // Take a mutable 2D (H, W) slice of this channel and rewrite each
        // pixel with its normalized value.
        img.slice_mut::<2, _>(chan)
            .apply(|x| normalize_pixel(*x, chan));
    }
}

/// Recognize text line images using TrOCR [^1].
///
/// First use Hugging Face's Optimum tool to download and export the models to
/// ONNX:
///
/// ```
/// optimum-cli export onnx --model microsoft/trocr-base-printed trocr-base-printed
/// ```
///
/// Convert the models to `.rten` format. For the decoder you need to use the
/// "merged" model.
///
/// ```
/// rten-convert trocr-base-printed/encoder_model.onnx
/// rten-convert trocr-base-printed/decoder_model_merged.onnx
/// ```
///
/// Run the model, specifying the image to recognize:
///
/// ```sh
/// cargo run --release --bin trocr trocr-base-printed/encoder_model.rten trocr-base-printed/decoder_model_merged.rten tokenizer.json <image>
/// ```
///
/// [^1]: https://arxiv.org/abs/2109.10282
fn main() -> Result<(), Box<dyn Error>> {
    let args = parse_args()?;

    // SAFETY: the model files are memory-mapped; presumably this requires
    // that the files are not modified while the models are in use — confirm
    // against `Model::load_mmap`'s documented contract.
    let encoder_model = unsafe { Model::load_mmap(args.encoder_model)? };
    let decoder_model = unsafe { Model::load_mmap(args.decoder_model)? };

    let tokenizer_config = fs::read_to_string(&args.tokenizer_config)?;
    let tokenizer = Tokenizer::from_json(&tokenizer_config)?;

    let mut image = read_image(args.image_path)?.into_dyn();
    image.insert_axis(0); // Add batch dim

    // From `image_size` in config.json.
    let mut image = image.resize_image([384, 384])?;
    normalize_image(image.slice_mut(0));

    // Run the encoder once to produce the image features the decoder attends
    // to on every generation step.
    let encoded_image: NdTensor<f32, 3> = encoder_model
        .run_one(image.view().into(), None)?
        .try_into()?;

    let encoder_hidden_states_id = decoder_model.node_id("encoder_hidden_states")?;

    // `decoder_start_token_id` from `generation_config.json`. This is the `</s>`
    // token.
    let decoder_start_token = 2;
    let eos_token = 2;

    // Cap on generated tokens, in case the model never emits `eos_token`.
    let max_tokens = 100;

    // Autoregressively generate tokens: the encoded image is fed as a
    // constant input on every step, and generation stops at the end-of-
    // sequence token or after `max_tokens` tokens.
    let prompt = vec![decoder_start_token];
    let generator = Generator::from_model(&decoder_model)?
        .with_prompt(&prompt)
        .with_constant_input(encoder_hidden_states_id, encoded_image.view().into())
        .stop_on_tokens([eos_token])
        .take(max_tokens)
        .decode(&tokenizer);

    // Print each decoded token as soon as it is produced, flushing so output
    // appears incrementally. Flush errors are deliberately ignored.
    for token in generator {
        let token = token?;

        print!("{}", token);
        let _ = std::io::stdout().flush();
    }

    Ok(())
}
2 changes: 0 additions & 2 deletions rten-text/src/tokenizers/bpe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,6 @@ impl Bpe {

if let Some(rank) = builder.get_token_rank(&token) {
rank_to_token_id.insert(rank, id);
} else if !added_tokens.values().any(|s| *s == token.as_str()) {
return Err(BpeError::InvalidVocabEntry(token));
}
}
(Some(rank_to_token_id), Some(token_id_to_encoded_bytes))
Expand Down
Loading