From ddf41093f4340e272fee25f80f25fb79eefd8705 Mon Sep 17 00:00:00 2001
From: Robert Knight
Date: Fri, 16 Aug 2024 22:38:20 +0100
Subject: [PATCH 1/3] Set `use_cache_branch` inputs to 0 by default in rten-cli

---
 rten-cli/src/main.rs | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/rten-cli/src/main.rs b/rten-cli/src/main.rs
index 94ea5bed..e63f5d3b 100644
--- a/rten-cli/src/main.rs
+++ b/rten-cli/src/main.rs
@@ -258,6 +258,12 @@ fn run_with_random_input(
             // of these.
             name if name.ends_with("_ids") => Output::from(Tensor::<i32>::zeros(&resolved_shape)),
 
+            // Optimum can export "merged" transformer models which have two
+            // branches. One accepts KV-cache inputs and the other does not.
+            // Set this to false as a "safer" value because we don't have
+            // cached outputs from a previous run.
+            "use_cache_branch" => Output::from(Tensor::from(0i32)),
+
             // For anything else, random floats in [0, 1].
             //
             // TODO - Value nodes in the model should include data types,

From 446b05cced374ecd08a8d464783be4088e8b7b82 Mon Sep 17 00:00:00 2001
From: Robert Knight
Date: Sun, 11 Aug 2024 08:25:59 +0100
Subject: [PATCH 2/3] Add TrOCR example

---
 rten-examples/Cargo.toml   |   4 +
 rten-examples/README.md    |   1 +
 rten-examples/src/trocr.rs | 148 +++++++++++++++++++++++++++++++++++++
 3 files changed, 153 insertions(+)
 create mode 100644 rten-examples/src/trocr.rs

diff --git a/rten-examples/Cargo.toml b/rten-examples/Cargo.toml
index e0f4e254..8af7807c 100644
--- a/rten-examples/Cargo.toml
+++ b/rten-examples/Cargo.toml
@@ -64,6 +64,10 @@ path = "src/depth_anything.rs"
 name = "segment_anything"
 path = "src/segment_anything.rs"
 
+[[bin]]
+name = "trocr"
+path = "src/trocr.rs"
+
 # Text
 [[bin]]
 name = "bert_qa"
diff --git a/rten-examples/README.md b/rten-examples/README.md
index ac8d9aad..cff893b3 100644
--- a/rten-examples/README.md
+++ b/rten-examples/README.md
@@ -57,6 +57,7 @@ The examples have been chosen to cover common tasks and popular models.
 - **detr** - Object detection using [DETR](https://research.facebook.com/publications/end-to-end-object-detection-with-transformers/)
 - **distilvit** - Image captioning using [Mozilla's DistilViT](https://hacks.mozilla.org/2024/05/experimenting-with-local-alt-text-generation-in-firefox-nightly/)
 - **segment_anything** - Image segmentation using [Segment Anything](https://segment-anything.com)
+- **trocr** - Recognize text using [TrOCR](https://arxiv.org/abs/2109.10282)
 - **yolo** - Object detection using [YOLO v8](https://github.com/ultralytics/ultralytics)
 
 ### Text
diff --git a/rten-examples/src/trocr.rs b/rten-examples/src/trocr.rs
new file mode 100644
index 00000000..3a4b4b5c
--- /dev/null
+++ b/rten-examples/src/trocr.rs
@@ -0,0 +1,148 @@
+use std::collections::VecDeque;
+use std::error::Error;
+use std::fs;
+use std::io::prelude::*;
+
+use rten::{FloatOperators, Model};
+use rten_generate::{Generator, GeneratorUtils};
+use rten_imageio::read_image;
+use rten_tensor::prelude::*;
+use rten_tensor::{NdTensor, NdTensorViewMut};
+use rten_text::tokenizers::Tokenizer;
+
+struct Args {
+    encoder_model: String,
+    decoder_model: String,
+    tokenizer_config: String,
+    image_path: String,
+}
+
+fn parse_args() -> Result<Args, lexopt::Error> {
+    use lexopt::prelude::*;
+
+    let mut values = VecDeque::new();
+    let mut parser = lexopt::Parser::from_env();
+
+    while let Some(arg) = parser.next()? {
+        match arg {
+            Value(val) => values.push_back(val.string()?),
+            Long("help") => {
+                println!(
+                    "Read text from an image containing a single text line.
+
+Usage: {bin_name} [options] <encoder_model> <decoder_model> <tokenizer> <image>
+
+Args:
+
+  <encoder_model> - Image encoder model
+  <decoder_model> - Text decoder model
+  <tokenizer> - `tokenizer.json` file
+  <image> - Image path
+",
+                    bin_name = parser.bin_name().unwrap_or("trocr")
+                );
+                std::process::exit(0);
+            }
+            _ => return Err(arg.unexpected()),
+        }
+    }
+
+    let encoder_model = values.pop_front().ok_or("missing `encoder_model` arg")?;
+    let decoder_model = values.pop_front().ok_or("missing `decoder_model` arg")?;
+    let tokenizer_config = values.pop_front().ok_or("missing `tokenizer` arg")?;
+    let image_path = values.pop_front().ok_or("missing `image_path` arg")?;
+
+    let args = Args {
+        encoder_model,
+        decoder_model,
+        tokenizer_config,
+        image_path,
+    };
+
+    Ok(args)
+}
+
+fn normalize_pixel(value: f32, channel: usize) -> f32 {
+    assert!(channel < 3, "channel index is invalid");
+
+    // Values taken from `preprocessor_config.json`.
+    let mean = [0.5, 0.5, 0.5];
+    let std_dev = [0.5, 0.5, 0.5];
+
+    (value - mean[channel]) / std_dev[channel]
+}
+
+fn normalize_image(mut img: NdTensorViewMut<f32, 3>) {
+    for chan in 0..img.size(0) {
+        img.slice_mut::<2, _>(chan)
+            .apply(|x| normalize_pixel(*x, chan));
+    }
+}
+
+/// Recognize text line images using TrOCR [^1].
+///
+/// First use Hugging Face's Optimum tool to download and export the models to
+/// ONNX:
+///
+/// ```
+/// optimum-cli export onnx --model microsoft/trocr-base-printed trocr-base-printed
+/// ```
+///
+/// Convert the models to `.rten` format. For the decoder you need to use the
+/// "merged" model.
+///
+/// ```
+/// rten-convert trocr-base-printed/encoder_model.onnx
+/// rten-convert trocr-base-printed/decoder_model_merged.onnx
+/// ```
+///
+/// Run the model, specifying the image to recognize:
+///
+/// ```sh
+/// cargo run --release --bin trocr trocr-base-printed/encoder_model.rten trocr-base-printed/decoder_model_merged.rten tokenizer.json <image>
+/// ```
+///
+/// [^1]: https://arxiv.org/abs/2109.10282
+fn main() -> Result<(), Box<dyn Error>> {
+    let args = parse_args()?;
+    let encoder_model = unsafe { Model::load_mmap(args.encoder_model)? };
+    let decoder_model = unsafe { Model::load_mmap(args.decoder_model)? };
+    let tokenizer_config = fs::read_to_string(&args.tokenizer_config)?;
+    let tokenizer = Tokenizer::from_json(&tokenizer_config)?;
+    let mut image = read_image(args.image_path)?.into_dyn();
+    image.insert_axis(0); // Add batch dim
+
+    // From `image_size` in config.json.
+    let mut image = image.resize_image([384, 384])?;
+    normalize_image(image.slice_mut(0));
+
+    let encoded_image: NdTensor<f32, 3> = encoder_model
+        .run_one(image.view().into(), None)?
+        .try_into()?;
+
+    let encoder_hidden_states_id = decoder_model.node_id("encoder_hidden_states")?;
+
+    // `decoder_start_token_id` from `generation_config.json`. This is the
+    // `</s>` token.
+    let decoder_start_token = 2;
+    let eos_token = 2;
+
+    let max_tokens = 100;
+
+    let prompt = vec![decoder_start_token];
+    let generator = Generator::from_model(&decoder_model)?
+        .with_prompt(&prompt)
+        .with_constant_input(encoder_hidden_states_id, encoded_image.view().into())
+        .stop_on_tokens([eos_token])
+        .take(max_tokens)
+        .decode(&tokenizer);
+
+    for token in generator {
+        let token = token?;
+
+        print!("{}", token);
+        let _ = std::io::stdout().flush();
+    }
+
+    Ok(())
+}

From 18e9b2a866de25991d84140ebce8c6a2b18a1217 Mon Sep 17 00:00:00 2001
From: Robert Knight
Date: Fri, 23 Aug 2024 08:59:50 +0100
Subject: [PATCH 3/3] Allow BPE tokenizers containing unused tokens in vocab

Support loading tokenizers which contain entries in the `vocab` map that
do not appear in either `merges` or `added_tokens`.

The TrOCR base model on Hugging Face
(https://huggingface.co/microsoft/trocr-base-printed) has an `<|endoftext|>`
token in the vocab which does not appear in the `merges` or `added_tokens`
fields.
---
 rten-text/src/tokenizers/bpe.rs | 2 --
 1 file changed, 2 deletions(-)

diff --git a/rten-text/src/tokenizers/bpe.rs b/rten-text/src/tokenizers/bpe.rs
index 8fe41974..d98b55ab 100644
--- a/rten-text/src/tokenizers/bpe.rs
+++ b/rten-text/src/tokenizers/bpe.rs
@@ -302,8 +302,6 @@ impl Bpe {
             if let Some(rank) = builder.get_token_rank(&token) {
                 rank_to_token_id.insert(rank, id);
-            } else if !added_tokens.values().any(|s| *s == token.as_str()) {
-                return Err(BpeError::InvalidVocabEntry(token));
             }
         }
 
         (Some(rank_to_token_id), Some(token_id_to_encoded_bytes))

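As a rough illustration of the behaviour this change allows, the sketch below builds a rank-to-token-id table for a BPE vocab using plain standard-library types: vocab entries with no merge rank (such as `<|endoftext|>` in the TrOCR tokenizer) are now skipped rather than rejected. The function and parameter names are illustrative only and are not the actual `rten-text` API.

```rust
use std::collections::HashMap;

// Illustrative sketch only (not the real `rten-text` API): build the
// rank -> token id table for a BPE tokenizer, skipping vocab entries that
// have no merge rank rather than treating them as an error.
fn build_rank_to_token_id(
    vocab: &HashMap<String, u32>,       // token string -> token id
    merge_ranks: &HashMap<String, u32>, // token string -> merge rank
) -> HashMap<u32, u32> {
    let mut rank_to_token_id = HashMap::new();
    for (token, &id) in vocab {
        if let Some(&rank) = merge_ranks.get(token) {
            rank_to_token_id.insert(rank, id);
        }
        // Entries that appear in `vocab` but in neither `merges` nor
        // `added_tokens` are simply ignored instead of producing an error.
    }
    rank_to_token_id
}

fn main() {
    let vocab = HashMap::from([
        ("hello".to_string(), 0),
        // Present in the vocab but absent from the merge list.
        ("<|endoftext|>".to_string(), 1),
    ]);
    let merge_ranks = HashMap::from([("hello".to_string(), 0)]);

    let table = build_rank_to_token_id(&vocab, &merge_ranks);
    // Only the token with a merge rank ends up in the table; the unused
    // vocab entry no longer causes loading to fail.
    assert_eq!(table.len(), 1);
}
```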