-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #304 from robertknight/trocr-example
Add TrOCR example
- Loading branch information
Showing
5 changed files
with
159 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
use std::collections::VecDeque; | ||
use std::error::Error; | ||
use std::fs; | ||
use std::io::prelude::*; | ||
|
||
use rten::{FloatOperators, Model}; | ||
use rten_generate::{Generator, GeneratorUtils}; | ||
use rten_imageio::read_image; | ||
use rten_tensor::prelude::*; | ||
use rten_tensor::{NdTensor, NdTensorViewMut}; | ||
use rten_text::tokenizers::Tokenizer; | ||
|
||
/// Parsed command-line arguments for the TrOCR example.
struct Args {
    // Path to the image encoder model (`.rten` file).
    encoder_model: String,
    // Path to the text decoder model (`.rten` file, "merged" variant).
    decoder_model: String,
    // Path to the `tokenizer.json` configuration file.
    tokenizer_config: String,
    // Path to the input image containing a single text line.
    image_path: String,
}
|
||
fn parse_args() -> Result<Args, lexopt::Error> { | ||
use lexopt::prelude::*; | ||
|
||
let mut values = VecDeque::new(); | ||
let mut parser = lexopt::Parser::from_env(); | ||
|
||
while let Some(arg) = parser.next()? { | ||
match arg { | ||
Value(val) => values.push_back(val.string()?), | ||
Long("help") => { | ||
println!( | ||
"Read text from an image containing a single text line. | ||
Usage: {bin_name} [options] <encoder_model> <decoder_model> <tokenizer> <image> | ||
Args: | ||
<encoder_model> - Image encoder model | ||
<decoder_model> - Text decoder model | ||
<tokenizer> - `tokenizer.json` file | ||
<image> - Image path | ||
", | ||
bin_name = parser.bin_name().unwrap_or("distilvit") | ||
); | ||
std::process::exit(0); | ||
} | ||
_ => return Err(arg.unexpected()), | ||
} | ||
} | ||
|
||
let encoder_model = values.pop_front().ok_or("missing `encoder_model` arg")?; | ||
let decoder_model = values.pop_front().ok_or("missing `decoder_model` arg")?; | ||
let tokenizer_config = values.pop_front().ok_or("missing `tokenizer` arg")?; | ||
let image_path = values.pop_front().ok_or("missing `image_path` arg")?; | ||
|
||
let args = Args { | ||
encoder_model, | ||
decoder_model, | ||
tokenizer_config, | ||
image_path, | ||
}; | ||
|
||
Ok(args) | ||
} | ||
|
||
/// Map a raw pixel value to the normalized range expected by the model,
/// using the per-channel mean and standard deviation.
///
/// # Panics
///
/// Panics if `channel` is not 0, 1 or 2.
fn normalize_pixel(value: f32, channel: usize) -> f32 {
    // Values taken from `preprocessor_config.json`.
    const MEAN: [f32; 3] = [0.5, 0.5, 0.5];
    const STD_DEV: [f32; 3] = [0.5, 0.5, 0.5];

    assert!(channel < 3, "channel index is invalid");

    (value - MEAN[channel]) / STD_DEV[channel]
}
|
||
fn normalize_image(mut img: NdTensorViewMut<f32, 3>) { | ||
for chan in 0..img.size(0) { | ||
img.slice_mut::<2, _>(chan) | ||
.apply(|x| normalize_pixel(*x, chan)); | ||
} | ||
} | ||
|
||
/// Recognize text line images using TrOCR [^1].
///
/// First use Hugging Face's Optimum tool to download and export the models to
/// ONNX:
///
/// ```
/// optimum-cli export onnx --model microsoft/trocr-base-printed trocr-base-printed
/// ```
///
/// Convert the models to `.rten` format. For the decoder you need to use the
/// "merged" model.
///
/// ```
/// rten-convert trocr-base-printed/encoder_model.onnx
/// rten-convert trocr-base-printed/decoder_model_merged.onnx
/// ```
///
/// Run the model, specifying the image to recognize:
///
/// ```sh
/// cargo run --release --bin trocr trocr-base-printed/encoder_model.rten trocr-base-printed/decoder_model_merged.rten tokenizer.json <image>
/// ```
///
/// [^1]: https://arxiv.org/abs/2109.10282
fn main() -> Result<(), Box<dyn Error>> {
    let args = parse_args()?;

    // Memory-map the model files instead of reading them fully into memory.
    // `load_mmap` is `unsafe`; presumably the caller must guarantee the files
    // are not modified while mapped — TODO confirm against rten docs.
    let encoder_model = unsafe { Model::load_mmap(args.encoder_model)? };
    let decoder_model = unsafe { Model::load_mmap(args.decoder_model)? };
    let tokenizer_config = fs::read_to_string(&args.tokenizer_config)?;
    let tokenizer = Tokenizer::from_json(&tokenizer_config)?;
    let mut image = read_image(args.image_path)?.into_dyn();
    image.insert_axis(0); // Add batch dim

    // From `image_size` in config.json.
    let mut image = image.resize_image([384, 384])?;
    // Normalize the first (and only) image in the batch in place.
    normalize_image(image.slice_mut(0));

    // Run the encoder once to produce the hidden states the decoder attends to.
    let encoded_image: NdTensor<f32, 3> = encoder_model
        .run_one(image.view().into(), None)?
        .try_into()?;

    // Input node of the decoder that receives the encoder's output.
    let encoder_hidden_states_id = decoder_model.node_id("encoder_hidden_states")?;

    // `decoder_start_token_id` from `generation_config.json`. This is the `</s>`
    // token.
    let decoder_start_token = 2;
    // Same ID also serves as the end-of-sequence token.
    let eos_token = 2;

    // Cap on generated tokens to guarantee termination even if EOS is never produced.
    let max_tokens = 100;

    // Autoregressive generation: start from the prompt, feed the encoded image
    // as a constant input, stop at EOS or after `max_tokens`, and decode token
    // IDs to text via the tokenizer.
    let prompt = vec![decoder_start_token];
    let generator = Generator::from_model(&decoder_model)?
        .with_prompt(&prompt)
        .with_constant_input(encoder_hidden_states_id, encoded_image.view().into())
        .stop_on_tokens([eos_token])
        .take(max_tokens)
        .decode(&tokenizer);

    // Stream each decoded text fragment to stdout as it is generated.
    for token in generator {
        let token = token?;

        print!("{}", token);
        // Flush so partial output appears immediately; ignore flush errors.
        let _ = std::io::stdout().flush();
    }

    Ok(())
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters