Skip to content

Commit

Permalink
Add Segment Anything example
Browse files Browse the repository at this point in the history
This takes an image and a query point as input and outputs a segmentation mask.
  • Loading branch information
robertknight committed Jul 29, 2024
1 parent 3980090 commit 44db797
Show file tree
Hide file tree
Showing 3 changed files with 203 additions and 0 deletions.
4 changes: 4 additions & 0 deletions rten-examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ path = "src/yolo.rs"
name = "depth_anything"
path = "src/depth_anything.rs"

[[bin]]
name = "segment_anything"
path = "src/segment_anything.rs"

# Text
[[bin]]
name = "bert_qa"
Expand Down
1 change: 1 addition & 0 deletions rten-examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ The examples have been chosen to cover common tasks and popular models.
- **depth_anything** - Monocular depth estimation using [Depth Anything](https://github.com/LiheYoung/Depth-Anything)
- **detr** - Object detection using [DETR](https://research.facebook.com/publications/end-to-end-object-detection-with-transformers/)
- **distilvit** - Image captioning using [Mozilla's DistilViT](https://hacks.mozilla.org/2024/05/experimenting-with-local-alt-text-generation-in-firefox-nightly/)
- **segment_anything** - Image segmentation using [Segment Anything](https://segment-anything.com)
- **yolo** - Object detection using [YOLO v8](https://github.com/ultralytics/ultralytics)

### Text
Expand Down
198 changes: 198 additions & 0 deletions rten-examples/src/segment_anything.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
use std::collections::VecDeque;
use std::error::Error;

use rten::{Dimension, FloatOperators, Model};
use rten_imageio::{read_image, write_image};
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, Tensor};

/// Parsed command-line arguments for the Segment Anything example.
struct Args {
    /// Path to the image encoder model file.
    encoder_model: String,
    /// Path to the prompt encoder / mask decoder model file.
    decoder_model: String,

    /// Path to input image to segment.
    image: String,

    /// (x, y) query points identifying the object(s) to generate segmentation
    /// masks for.
    points: Vec<(u32, u32)>,
}

/// Parse command-line arguments.
///
/// Expects four positional arguments: the encoder model path, the decoder
/// model path, the image path and a `x1,y1;x2,y2;...` list of query points.
///
/// Returns an error if a required argument is missing, an unexpected flag is
/// given, or the points list is empty or malformed.
fn parse_args() -> Result<Args, lexopt::Error> {
    use lexopt::prelude::*;

    let mut values = VecDeque::new();
    let mut parser = lexopt::Parser::from_env();

    while let Some(arg) = parser.next()? {
        match arg {
            Value(val) => values.push_back(val.string()?),
            Long("help") => {
                println!(
                    "Segment an image.
Usage: {bin_name} <encoder_model> <decoder_model> <image> <points>
Args:
<encoder_model> - Image encoder model
<decoder_model> - Prompt decoder model
<image> - Image to process
<points> -
List of points identifying the object to segment.
This has the form `x1,y1;x2,y2;...`. At least one point must be provided.
",
                    bin_name = parser.bin_name().unwrap_or("segment_anything")
                );
                std::process::exit(0);
            }
            _ => return Err(arg.unexpected()),
        }
    }

    let encoder_model = values.pop_front().ok_or("missing `encoder_model` arg")?;
    let decoder_model = values.pop_front().ok_or("missing `decoder_model` arg")?;
    let image = values.pop_front().ok_or("missing `image` arg")?;
    let points_str = values.pop_front().ok_or("missing `points` arg")?;

    let mut points: Vec<(u32, u32)> = Vec::new();
    for xy_str in points_str.split(';') {
        let xy_str = xy_str.trim();

        // Skip empty segments so that trailing or doubled separators
        // (eg. `10,20;`) are accepted rather than reported as malformed pairs.
        if xy_str.is_empty() {
            continue;
        }

        let Some(xy_coords) = xy_str.split_once(',') else {
            return Err(lexopt::Error::Custom(
                "points should be x,y coordinate pairs".into(),
            ));
        };

        // Trim each coordinate so `10, 20` parses as well as `10,20`.
        let (Ok(x), Ok(y)) = (xy_coords.0.trim().parse(), xy_coords.1.trim().parse()) else {
            return Err(lexopt::Error::Custom(
                "points should be positive integer values".into(),
            ));
        };
        points.push((x, y));
    }

    // The help text promises at least one point; enforce it with a clear
    // message instead of letting an empty list flow into the model run.
    if points.is_empty() {
        return Err(lexopt::Error::Custom(
            "at least one point must be provided".into(),
        ));
    }

    let args = Args {
        image,
        encoder_model,
        decoder_model,
        points,
    };

    Ok(args)
}

/// Perform image segmentation using Segment Anything [^1].
///
/// First export the ONNX model using Hugging Face's Optimum tool:
///
/// ```
/// optimum-cli export onnx --model facebook/sam-vit-base sam-vit-base
/// ```
///
/// Then convert the models to `.rten` format and run the demo, specifying a
/// path to the image to segment and one or more points in the image identifying
/// the object of interest.
///
/// ```
/// rten-convert sam-vit-base/vision_encoder.onnx
/// rten-convert sam-vit-base/prompt_encoder_mask_decoder.onnx
/// cargo run --release --bin segment_anything sam-vit-base/vision_encoder.rten sam-vit-base/prompt_encoder_mask_decoder.rten image.jpg points
/// ```
///
/// Where `points` is a semi-colon separated list of x,y point pairs identifying
/// the objects to segment. For example `200,300;205,305` generates a
/// segmentation mask for the object containing the points (200, 300) and (205,
/// 305).
///
/// [^1]: https://segment-anything.com
fn main() -> Result<(), Box<dyn Error>> {
    let args = parse_args()?;

    println!("Loading model...");
    let encoder_model = Model::load_file(args.encoder_model)?;
    let decoder_model = Model::load_file(args.decoder_model)?;

    println!("Reading image...");
    // The loaded image is (channels, height, width); record the original size
    // (used later to scale query points and the output mask) and add a batch
    // dimension to get the NCHW layout the encoder expects.
    let mut image: Tensor = read_image(&args.image)?.into();
    let image_h = image.size(1);
    let image_w = image.size(2);
    image.insert_axis(0);

    // Generate image embeddings.
    let pixel_values_id = encoder_model.node_id("pixel_values")?;
    // Read the expected input size from the model's metadata when the height
    // and width dimensions are fixed, otherwise fall back to 1024x1024.
    let [input_h, input_w] = match encoder_model
        .node_info(pixel_values_id)
        .and_then(|ni| ni.shape())
        .as_deref()
    {
        Some(&[_, _, Dimension::Fixed(h), Dimension::Fixed(w)]) => [h, w],
        _ => [1024, 1024],
    };
    let image = image.resize_image([input_h, input_w])?;

    println!("Generating image embedding...");
    let image_embeddings_id = encoder_model.node_id("image_embeddings")?;
    let image_pos_embeddings_id = encoder_model.node_id("image_positional_embeddings")?;

    // TODO - Clarify normalization required for `pixel_values`.
    let [image_embeddings, image_pos_embeddings] = encoder_model.run_n(
        vec![(pixel_values_id, image.view().into())],
        [image_embeddings_id, image_pos_embeddings_id],
        None,
    )?;

    println!("Segmenting image with {} points...", args.points.len());

    // Prepare decoder inputs.
    let input_points_id = decoder_model.node_id("input_points")?;
    let input_labels_id = decoder_model.node_id("input_labels")?;
    let decoder_embeddings_id = decoder_model.node_id("image_embeddings")?;
    let decoder_pos_embeddings_id = decoder_model.node_id("image_positional_embeddings")?;

    let iou_scores_id = decoder_model.node_id("iou_scores")?;
    let pred_masks_id = decoder_model.node_id("pred_masks")?;

    // Query points are given in original-image coordinates; scale them into
    // the resized (input_h x input_w) space the encoder saw.
    let h_scale = input_h as f32 / image_h as f32;
    let w_scale = input_w as f32 / image_w as f32;

    // Build the (batch, point_batch, points, xy) input tensor; coord 0 is x,
    // coord 1 is y.
    let point_batch = 1;
    let nb_points_per_image = args.points.len();
    let input_points = NdTensor::from_fn(
        [1, point_batch, nb_points_per_image, 2],
        |[_, _, point, coord]| {
            if coord == 0 {
                args.points[point].0 as f32 * w_scale
            } else {
                args.points[point].1 as f32 * h_scale
            }
        },
    );

    // Label every query point as part of the object of interest
    // (presumably following SAM's point-label convention — TODO confirm).
    const MATCH_POINT: i32 = 1;
    const _NON_MATCH_POINT: i32 = 0;
    const _BACKGROUND_POINT: i32 = -1;
    let input_labels = NdTensor::<i32, 3>::full([1, point_batch, nb_points_per_image], MATCH_POINT);

    // Run decoder and generate segmentation masks.
    let [_iou_scores, pred_masks] = decoder_model.run_n(
        vec![
            (input_points_id, input_points.into()),
            (input_labels_id, input_labels.into()),
            (decoder_embeddings_id, image_embeddings.into()),
            (decoder_pos_embeddings_id, image_pos_embeddings.into()),
        ],
        [iou_scores_id, pred_masks_id],
        None,
    )?;

    // Take the first predicted mask for the first point batch, resize it back
    // to the original image size and write it out as an image.
    let pred_masks: NdTensor<f32, 5> = pred_masks.try_into()?;
    let [_batch, _point_batch, _mask, mask_h, mask_w] = pred_masks.shape();
    let best_mask = pred_masks
        .slice::<2, _>((0, 0, 0))
        .reshaped([1, 1, mask_h, mask_w]);
    let resized_mask = best_mask.resize_image([image_h, image_w])?;
    write_image("segmented.png", resized_mask.slice::<3, _>(0).nd_view())?;

    Ok(())
}

0 comments on commit 44db797

Please sign in to comment.