-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
97a473c
commit c7e5d7b
Showing
3 changed files
with
149 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
use std::collections::VecDeque; | ||
use std::error::Error; | ||
use std::fs; | ||
use std::io::prelude::*; | ||
|
||
use rten::{FloatOperators, Model}; | ||
use rten_generate::{Generator, GeneratorUtils}; | ||
use rten_imageio::read_image; | ||
use rten_tensor::prelude::*; | ||
use rten_tensor::{NdTensor, NdTensorViewMut}; | ||
use rten_text::tokenizers::Tokenizer; | ||
|
||
/// Command-line arguments: paths to the two model files, the tokenizer
/// configuration and the input image.
struct Args {
    /// Path to the image encoder model (`.rten` format).
    encoder_model: String,
    /// Path to the text decoder model (`.rten` format).
    decoder_model: String,
    /// Path to the `tokenizer.json` file.
    tokenizer_config: String,
    /// Path to the image to caption.
    image_path: String,
}
|
||
fn parse_args() -> Result<Args, lexopt::Error> { | ||
use lexopt::prelude::*; | ||
|
||
let mut values = VecDeque::new(); | ||
let mut parser = lexopt::Parser::from_env(); | ||
|
||
while let Some(arg) = parser.next()? { | ||
match arg { | ||
Value(val) => values.push_back(val.string()?), | ||
Long("help") => { | ||
println!( | ||
"Generate a caption for an image. | ||
Usage: {bin_name} [options] <encoder_model> <decoder_model> <tokenizer> <image> | ||
Args: | ||
<encoder_model> - Image encoder model | ||
<decoder_model> - Text decoder model | ||
<tokenizer> - `tokenizer.json` file | ||
<image> - Image path | ||
", | ||
bin_name = parser.bin_name().unwrap_or("distilvit") | ||
); | ||
std::process::exit(0); | ||
} | ||
_ => return Err(arg.unexpected()), | ||
} | ||
} | ||
|
||
let encoder_model = values.pop_front().ok_or("missing `encoder_model` arg")?; | ||
let decoder_model = values.pop_front().ok_or("missing `decoder_model` arg")?; | ||
let tokenizer_config = values.pop_front().ok_or("missing `tokenizer` arg")?; | ||
let image_path = values.pop_front().ok_or("missing `image_path` arg")?; | ||
|
||
let args = Args { | ||
encoder_model, | ||
decoder_model, | ||
tokenizer_config, | ||
image_path, | ||
}; | ||
|
||
Ok(args) | ||
} | ||
|
||
/// Map a raw pixel value to the normalized range expected by the model,
/// using the per-channel mean and standard deviation taken from
/// `preprocessor_config.json`.
///
/// # Panics
///
/// Panics if `channel` is not 0, 1 or 2.
fn normalize_pixel(value: f32, channel: usize) -> f32 {
    // Normalization constants from `preprocessor_config.json`.
    const MEAN: [f32; 3] = [0.5, 0.5, 0.5];
    const STD_DEV: [f32; 3] = [0.5, 0.5, 0.5];

    assert!(channel < 3, "channel index is invalid");

    (value - MEAN[channel]) / STD_DEV[channel]
}
|
||
/// Normalize every element of an image tensor in place using the
/// per-channel constants in [`normalize_pixel`].
///
/// NOTE(review): assumes axis 0 of `img` is the channel axis — the first
/// component of each index is passed to `normalize_pixel` as the channel.
/// Confirm against `read_image`'s output layout.
fn normalize_image(mut img: NdTensorViewMut<f32, 3>) {
    // Walk indices and values in lockstep; only the channel component of
    // the index is needed to select the normalization constants.
    for ([chan, _y, _x], pixel) in img.indices().zip(img.iter_mut()) {
        *pixel = normalize_pixel(*pixel, chan);
    }
}
|
||
/// Generate a caption for an image using the DistilViT model (an image
/// encoder plus autoregressive text decoder exported to `.rten` format).
///
/// NOTE(review): the previous doc comment here described TrOCR text-line
/// recognition and appears to have been copy-pasted from the `trocr`
/// example; the `--help` text, the `distilvit` binary name and the
/// captioning prompt show this program is an image-captioning example.
///
/// First use Hugging Face's Optimum tool to download and export the model
/// to ONNX (presumably `mozilla/distilvit` — TODO confirm the model id):
///
/// ```
/// optimum-cli export onnx --model mozilla/distilvit distilvit
/// ```
///
/// Convert the models to `.rten` format:
///
/// ```
/// rten-convert distilvit/encoder_model.onnx
/// rten-convert distilvit/decoder_model.onnx
/// ```
///
/// Run the model, specifying the image to caption:
///
/// ```sh
/// cargo run --release --bin distilvit encoder_model.rten decoder_model.rten tokenizer.json <image>
/// ```
fn main() -> Result<(), Box<dyn Error>> {
    let args = parse_args()?;

    // `load_mmap` is unsafe because it maps the file into memory: the model
    // files must not be modified by another process while this program runs.
    let encoder_model = unsafe { Model::load_mmap(args.encoder_model)? };
    let decoder_model = unsafe { Model::load_mmap(args.decoder_model)? };
    let tokenizer_config = fs::read_to_string(&args.tokenizer_config)?;
    let tokenizer = Tokenizer::from_json(&tokenizer_config)?;

    // Load the image, add a batch dimension, resize to 384x384 (the input
    // size this example's model expects) and normalize each channel.
    let mut image = read_image(args.image_path)?.into_dyn();
    image.insert_axis(0); // Add batch dim
    let mut image = image.resize_image([384, 384])?;
    normalize_image(image.slice_mut(0));

    // Run the encoder once; its hidden states are then fed to the decoder
    // as a constant input at every generation step.
    let encoded_image: NdTensor<f32, 3> = encoder_model
        .run_one(image.view().into(), None)?
        .try_into()?;

    let encoder_hidden_states_id = decoder_model.node_id("encoder_hidden_states")?;

    // `decoder_start_token_id` from `generation_config.json`. This is the `</s>`
    // token.
    let decoder_start_token = 2;
    let eos_token = 2;

    // Cap generation so a decoder that never emits EOS still terminates.
    let max_tokens = 100;

    let prompt = vec![decoder_start_token];
    let generator = Generator::from_model(&decoder_model)?
        .with_prompt(&prompt)
        .with_constant_input(encoder_hidden_states_id, encoded_image.view().into())
        .stop_on_tokens([eos_token])
        .take(max_tokens)
        .decode(&tokenizer);

    // Stream decoded tokens to stdout as they are generated.
    for token in generator {
        let token = token?;

        print!("{}", token);
        // Flush so partial output appears immediately; flush errors are
        // deliberately ignored (best-effort streaming).
        let _ = std::io::stdout().flush();
    }

    Ok(())
}