Run Ocrs parallel recognition in RTen's thread pool
RTen model execution is multi-threaded by default. To avoid contention between
threads in the Rayon global thread pool and RTen's thread pool, run parallel
recognition over batches of line images in RTen's thread pool.
robertknight committed May 25, 2024
1 parent ce42a95 commit 71f6176
Showing 1 changed file with 49 additions and 46 deletions.
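
The change wraps the existing Rayon parallel iteration over line-image batches inside RTen's thread pool, so the per-batch parallelism and the model's own execution threads share one pool rather than contending across Rayon's global pool and RTen's pool. Below is a minimal sketch of that pattern, outside the ocrs code, using the `rten::thread_pool()` helper and `run()` call that appear in the diff; `recognize_batch` and the numeric batch IDs are hypothetical stand-ins for the real batch preparation, model run, and CTC decoding.

use rayon::prelude::*;
use rten::thread_pool;

// Hypothetical stand-in for the real per-batch work in ocrs
// (prepare_text_line_batch + model run + CTC decoding).
fn recognize_batch(batch_id: usize) -> Result<Vec<String>, String> {
    Ok(vec![format!("line text for batch {}", batch_id)])
}

fn main() {
    let batches: Vec<usize> = (0..8).collect();

    // Running the Rayon parallel iterator inside `thread_pool().run(...)`
    // schedules the per-batch work on RTen's thread pool instead of Rayon's
    // global pool, avoiding the cross-pool contention the commit describes.
    let results: Result<Vec<Vec<String>>, String> = thread_pool().run(|| {
        batches
            .into_par_iter()
            .map(recognize_batch)
            .collect()
    });

    println!("{:?}", results);
}
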
95 changes: 49 additions & 46 deletions ocrs/src/recognition.rs
@@ -3,7 +3,7 @@ use std::collections::HashMap;
 use anyhow::anyhow;
 use rayon::prelude::*;
 use rten::ctc::{CtcDecoder, CtcHypothesis};
-use rten::{Dimension, FloatOperators, Model, NodeId};
+use rten::{thread_pool, Dimension, FloatOperators, Model, NodeId};
 use rten_imageproc::{
     bounding_rect, BoundingRect, Line, Point, PointF, Polygon, Rect, RotatedRect,
 };
@@ -473,53 +473,56 @@ impl TextRecognizer {
             .collect();

         // Run text recognition on batches of lines.
-        let batch_rec_results: Result<Vec<Vec<LineRecResult>>, ModelRunError> = line_groups
-            .into_par_iter()
-            .map(|(group_width, lines)| {
-                if debug {
-                    println!(
-                        "Processing group of {} lines of width {}",
-                        lines.len(),
-                        group_width,
-                    );
-                }
-
-                let rec_input = prepare_text_line_batch(
-                    &image,
-                    &lines,
-                    page_rect,
-                    rec_img_height as usize,
-                    group_width as usize,
-                );
-
-                let rec_output = self.run(rec_input)?;
-                let ctc_input_len = rec_output.shape()[1];
-
-                // Apply CTC decoding to get the label sequence for each line.
-                let line_rec_results = lines
-                    .into_iter()
-                    .enumerate()
-                    .map(|(group_line_index, line)| {
-                        let decoder = CtcDecoder::new();
-                        let input_seq = rec_output.slice([group_line_index]);
-                        let ctc_output = match decode_method {
-                            DecodeMethod::Greedy => decoder.decode_greedy(input_seq),
-                            DecodeMethod::BeamSearch { width } => {
-                                decoder.decode_beam(input_seq, width)
-                            }
-                        };
-                        LineRecResult {
-                            line,
-                            rec_input_len: group_width as usize,
-                            ctc_input_len,
-                            ctc_output,
-                        }
-                    })
-                    .collect::<Vec<_>>();
-
-                Ok(line_rec_results)
-            })
-            .collect();
+        let batch_rec_results: Result<Vec<Vec<LineRecResult>>, ModelRunError> =
+            thread_pool().run(|| {
+                line_groups
+                    .into_par_iter()
+                    .map(|(group_width, lines)| {
+                        if debug {
+                            println!(
+                                "Processing group of {} lines of width {}",
+                                lines.len(),
+                                group_width,
+                            );
+                        }
+
+                        let rec_input = prepare_text_line_batch(
+                            &image,
+                            &lines,
+                            page_rect,
+                            rec_img_height as usize,
+                            group_width as usize,
+                        );
+
+                        let rec_output = self.run(rec_input)?;
+                        let ctc_input_len = rec_output.shape()[1];
+
+                        // Apply CTC decoding to get the label sequence for each line.
+                        let line_rec_results = lines
+                            .into_iter()
+                            .enumerate()
+                            .map(|(group_line_index, line)| {
+                                let decoder = CtcDecoder::new();
+                                let input_seq = rec_output.slice([group_line_index]);
+                                let ctc_output = match decode_method {
+                                    DecodeMethod::Greedy => decoder.decode_greedy(input_seq),
+                                    DecodeMethod::BeamSearch { width } => {
+                                        decoder.decode_beam(input_seq, width)
+                                    }
+                                };
+                                LineRecResult {
+                                    line,
+                                    rec_input_len: group_width as usize,
+                                    ctc_input_len,
+                                    ctc_output,
+                                }
+                            })
+                            .collect::<Vec<_>>();
+
+                        Ok(line_rec_results)
+                    })
+                    .collect()
+            });

         let mut line_rec_results: Vec<LineRecResult> =
             batch_rec_results?.into_iter().flatten().collect();
