Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add partial implementation of Einsum and Segment Anything example #295

Merged
merged 4 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions rten-convert/rten_convert/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,10 @@ def op_node_from_onnx_operator(
op_reader.check_attr("exclusive", "int", 0)
op_reader.check_attr("reverse", "int", 0)

case "Einsum":
attrs = sg.EinsumAttrsT()
attrs.equation = op_reader.require_attr("equation", "string")

case "Elu":
attrs = sg.EluAttrsT()
attrs.alpha = op_reader.get_attr("alpha", "float", 1.0)
Expand Down
86 changes: 85 additions & 1 deletion rten-convert/rten_convert/schema_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ class OperatorType(object):
Softplus = 100
GatherND = 101
Gelu = 102
Einsum = 103


class RNNDirection(object):
Expand Down Expand Up @@ -185,6 +186,7 @@ class OperatorAttrs(object):
RandomNormalLikeAttrs = 35
GatherNDAttrs = 36
GeluAttrs = 37
EinsumAttrs = 38

def OperatorAttrsCreator(unionType, table):
from flatbuffers.table import Table
Expand Down Expand Up @@ -264,6 +266,8 @@ def OperatorAttrsCreator(unionType, table):
return GatherNDAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().GeluAttrs:
return GeluAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().EinsumAttrs:
return EinsumAttrsT.InitFromBuf(table.Bytes, table.Pos)
return None


Expand Down Expand Up @@ -1588,6 +1592,86 @@ def Pack(self, builder):
return convTransposeAttrs


class EinsumAttrs(object):
    """Flatbuffers accessor for the `EinsumAttrs` table.

    Auto-generated by the FlatBuffers compiler (flatc); edits here are
    normally overwritten on schema regeneration. Reads fields lazily from a
    serialized buffer without unpacking them.
    """
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an EinsumAttrs view over the root table of `buf`."""
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = EinsumAttrs()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsEinsumAttrs(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)
    @classmethod
    def EinsumAttrsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        # b"\x52\x54\x45\x4E" is the "RTEN" file identifier.
        return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x52\x54\x45\x4E", size_prefixed=size_prefixed)

    # EinsumAttrs
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # EinsumAttrs
    def Equation(self):
        # Field at vtable offset 4; returns None when the field is absent.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.String(o + self._tab.Pos)
        return None

def EinsumAttrsStart(builder):
    # Begin an EinsumAttrs table with 1 field slot (generated by flatc).
    builder.StartObject(1)

def EinsumAttrsAddEquation(builder, equation):
    # `equation` is an offset to a string already written with CreateString.
    builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(equation), 0)

def EinsumAttrsEnd(builder):
    # Finish the table and return its offset within the buffer.
    return builder.EndObject()



class EinsumAttrsT(object):
    """Object-API (unpacked, mutable) form of the `EinsumAttrs` table.

    Auto-generated by the FlatBuffers compiler (flatc); normally regenerated
    from the schema rather than edited by hand.
    """

    # EinsumAttrsT
    def __init__(self):
        # Einsum equation string (None until set by the converter/unpacker).
        self.equation = None # type: str

    @classmethod
    def InitFromBuf(cls, buf, pos):
        """Unpack an EinsumAttrsT from the table at `pos` in `buf`."""
        einsumAttrs = EinsumAttrs()
        einsumAttrs.Init(buf, pos)
        return cls.InitFromObj(einsumAttrs)

    @classmethod
    def InitFromPackedBuf(cls, buf, pos=0):
        """Unpack from a buffer that still carries its root offset prefix."""
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
        return cls.InitFromBuf(buf, pos+n)

    @classmethod
    def InitFromObj(cls, einsumAttrs):
        """Copy fields from an `EinsumAttrs` accessor into a new instance."""
        x = EinsumAttrsT()
        x._UnPack(einsumAttrs)
        return x

    # EinsumAttrsT
    def _UnPack(self, einsumAttrs):
        if einsumAttrs is None:
            return
        self.equation = einsumAttrs.Equation()

    # EinsumAttrsT
    def Pack(self, builder):
        # Strings must be serialized before the enclosing table is started.
        if self.equation is not None:
            equation = builder.CreateString(self.equation)
        EinsumAttrsStart(builder)
        if self.equation is not None:
            EinsumAttrsAddEquation(builder, equation)
        einsumAttrs = EinsumAttrsEnd(builder)
        return einsumAttrs


class EluAttrs(object):
__slots__ = ['_tab']

Expand Down Expand Up @@ -4585,7 +4669,7 @@ class OperatorNodeT(object):
def __init__(self):
self.type = 0 # type: int
self.attrsType = 0 # type: int
self.attrs = None # type: Union[None, ArgMaxAttrsT, AveragePoolAttrsT, BatchNormalizationAttrsT, CastAttrsT, ConcatAttrsT, ConstantOfShapeAttrsT, ConvAttrsT, ConvTransposeAttrsT, FlattenAttrsT, GatherAttrsT, GemmAttrsT, GRUAttrsT, LeakyReluAttrsT, LSTMAttrsT, MaxPoolAttrsT, ReduceMeanAttrsT, ReshapeAttrsT, ResizeAttrsT, SplitAttrsT, SoftmaxAttrsT, TransposeAttrsT, ModAttrsT, ScatterElementsAttrsT, OneHotAttrsT, TopKAttrsT, HardSigmoidAttrsT, TriluAttrsT, ScatterNDAttrsT, NonMaxSuppressionAttrsT, LayerNormalizationAttrsT, RandomUniformAttrsT, EluAttrsT, RandomUniformLikeAttrsT, RandomNormalAttrsT, RandomNormalLikeAttrsT, GatherNDAttrsT, GeluAttrsT]
self.attrs = None # type: Union[None, ArgMaxAttrsT, AveragePoolAttrsT, BatchNormalizationAttrsT, CastAttrsT, ConcatAttrsT, ConstantOfShapeAttrsT, ConvAttrsT, ConvTransposeAttrsT, FlattenAttrsT, GatherAttrsT, GemmAttrsT, GRUAttrsT, LeakyReluAttrsT, LSTMAttrsT, MaxPoolAttrsT, ReduceMeanAttrsT, ReshapeAttrsT, ResizeAttrsT, SplitAttrsT, SoftmaxAttrsT, TransposeAttrsT, ModAttrsT, ScatterElementsAttrsT, OneHotAttrsT, TopKAttrsT, HardSigmoidAttrsT, TriluAttrsT, ScatterNDAttrsT, NonMaxSuppressionAttrsT, LayerNormalizationAttrsT, RandomUniformAttrsT, EluAttrsT, RandomUniformLikeAttrsT, RandomNormalAttrsT, RandomNormalLikeAttrsT, GatherNDAttrsT, GeluAttrsT, EinsumAttrsT]
self.inputs = None # type: List[int]
self.outputs = None # type: List[int]

Expand Down
4 changes: 4 additions & 0 deletions rten-examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ path = "src/yolo.rs"
name = "depth_anything"
path = "src/depth_anything.rs"

[[bin]]
name = "segment_anything"
path = "src/segment_anything.rs"

# Text
[[bin]]
name = "bert_qa"
Expand Down
1 change: 1 addition & 0 deletions rten-examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ The examples have been chosen to cover common tasks and popular models.
- **depth_anything** - Monocular depth estimation using [Depth Anything](https://github.com/LiheYoung/Depth-Anything)
- **detr** - Object detection using [DETR](https://research.facebook.com/publications/end-to-end-object-detection-with-transformers/)
- **distilvit** - Image captioning using [Mozilla's DistilViT](https://hacks.mozilla.org/2024/05/experimenting-with-local-alt-text-generation-in-firefox-nightly/)
- **segment_anything** - Image segmentation using [Segment Anything](https://segment-anything.com)
- **yolo** - Object detection using [YOLO v8](https://github.com/ultralytics/ultralytics)

### Text
Expand Down
217 changes: 217 additions & 0 deletions rten-examples/src/segment_anything.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
use std::collections::VecDeque;
use std::error::Error;

use rten::{Dimension, FloatOperators, Model};
use rten_imageio::{read_image, write_image};
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, Tensor};

/// Parsed command-line arguments for the Segment Anything example.
struct Args {
    /// Path to image encoder model.
    encoder_model: String,

    /// Path to prompt encoder / mask decoder model.
    decoder_model: String,

    /// Path to input image to segment.
    image: String,

    /// (x, y) query points identifying the object(s) to generate segmentation
    /// masks for. Coordinates are pixels in the original input image.
    points: Vec<(u32, u32)>,
}

/// Parse command-line arguments.
///
/// Positional args are, in order: encoder model path, decoder model path,
/// image path, and a `x1,y1;x2,y2;...` list of query points. Empty point
/// segments (e.g. from a trailing `;`) are ignored, but at least one point
/// must be supplied, matching the contract stated in `--help`.
fn parse_args() -> Result<Args, lexopt::Error> {
    use lexopt::prelude::*;

    let mut values = VecDeque::new();
    let mut parser = lexopt::Parser::from_env();

    while let Some(arg) = parser.next()? {
        match arg {
            Value(val) => values.push_back(val.string()?),
            Long("help") => {
                println!(
                    "Segment an image.

Usage: {bin_name} <encoder_model> <decoder_model> <image> <points>

Args:

  <encoder_model> - Image encoder model
  <decoder_model> - Prompt decoder model
  <image> - Image to process
  <points> -

    List of points identifying the object to segment.

    This has the form `x1,y1;x2,y2;...`. At least one point must be provided.
",
                    bin_name = parser.bin_name().unwrap_or("segment_anything")
                );
                std::process::exit(0);
            }
            _ => return Err(arg.unexpected()),
        }
    }

    let encoder_model = values.pop_front().ok_or("missing `encoder_model` arg")?;
    let decoder_model = values.pop_front().ok_or("missing `decoder_model` arg")?;
    let image = values.pop_front().ok_or("missing `image` arg")?;
    let points_str = values.pop_front().ok_or("missing `points` arg")?;

    let mut points: Vec<(u32, u32)> = Vec::new();
    for xy_str in points_str.split(';') {
        let xy_str = xy_str.trim();
        // Tolerate empty segments (e.g. a trailing ";") instead of failing
        // with a confusing "not a coordinate pair" error.
        if xy_str.is_empty() {
            continue;
        }
        let Some(xy_coords) = xy_str.split_once(',') else {
            return Err(lexopt::Error::Custom(
                "points should be x,y coordinate pairs".into(),
            ));
        };
        // Trim each half so "200, 300" parses as well as "200,300".
        let (Ok(x), Ok(y)) = (xy_coords.0.trim().parse(), xy_coords.1.trim().parse()) else {
            return Err(lexopt::Error::Custom(
                "points should be positive integer values".into(),
            ));
        };
        points.push((x, y));
    }

    // Enforce the documented requirement that at least one point is given.
    if points.is_empty() {
        return Err(lexopt::Error::Custom(
            "at least one point must be provided".into(),
        ));
    }

    let args = Args {
        image,
        encoder_model,
        decoder_model,
        points,
    };

    Ok(args)
}

/// Perform image segmentation using Segment Anything [^1].
///
/// First export the ONNX model using Hugging Face's Optimum tool:
///
/// ```
/// optimum-cli export onnx --model facebook/sam-vit-base sam-vit-base
/// ```
///
/// Then convert the models to `.rten` format and run the demo, specifying a
/// path to the image to segment and one or more points in the image identifying
/// the object of interest.
///
/// ```
/// rten-convert sam-vit-base/vision_encoder.onnx
/// rten-convert sam-vit-base/prompt_encoder_mask_decoder.onnx
/// cargo run --release --bin segment_anything sam-vit-base/vision_encoder.rten sam-vit-base/prompt_encoder_mask_decoder.rten image.jpg points
/// ```
///
/// Where `points` is a semi-colon separated list of x,y pixel coordinates
/// identifying the objects to segment. For example `200,300;205,305` generates
/// a segmentation mask for the object containing the points (200, 300) and
/// (205, 305). At least one point must be specified.
///
/// ## Alternative models
///
/// The original SAM model uses a computationally expensive vision encoder
/// paired with a lightweight prompt decoder. Since its release various teams
/// have created alternatives with faster image encoders. For faster generation
/// of image embeddings, you can try alternatives such as:
///
/// - [SlimSAM](https://huggingface.co/Zigeng/SlimSAM-uniform-50)
///
/// The process for exporting and converting the models is the same as for
/// the `facebook/sam-vit-base` model.
///
/// [^1]: https://segment-anything.com
fn main() -> Result<(), Box<dyn Error>> {
    let args = parse_args()?;

    println!("Loading model...");
    let encoder_model = Model::load_file(args.encoder_model)?;
    let decoder_model = Model::load_file(args.decoder_model)?;

    println!("Reading image...");
    let mut image: Tensor = read_image(&args.image)?.into();
    let image_h = image.size(1);
    let image_w = image.size(2);
    // Add a leading batch dimension before feeding the encoder.
    image.insert_axis(0);

    // Prepare the input image.
    //
    // This currently does the mandatory resizing of the input image, but
    // doesn't normalize the pixel values.
    let pixel_values_id = encoder_model.node_id("pixel_values")?;
    // Use the encoder's declared fixed input size if available; otherwise
    // fall back to 1024x1024.
    let [input_h, input_w] = match encoder_model
        .node_info(pixel_values_id)
        .and_then(|ni| ni.shape())
        .as_deref()
    {
        Some(&[_, _, Dimension::Fixed(h), Dimension::Fixed(w)]) => [h, w],
        _ => [1024, 1024],
    };
    let image = image.resize_image([input_h, input_w])?;

    // Generate image embeddings.
    println!("Generating image embedding...");
    let image_embeddings_id = encoder_model.node_id("image_embeddings")?;
    let image_pos_embeddings_id = encoder_model.node_id("image_positional_embeddings")?;

    let [image_embeddings, image_pos_embeddings] = encoder_model.run_n(
        vec![(pixel_values_id, image.view().into())],
        [image_embeddings_id, image_pos_embeddings_id],
        None,
    )?;

    println!("Segmenting image with {} points...", args.points.len());

    // Prepare decoder inputs.
    let input_points_id = decoder_model.node_id("input_points")?;
    let input_labels_id = decoder_model.node_id("input_labels")?;
    let decoder_embeddings_id = decoder_model.node_id("image_embeddings")?;
    let decoder_pos_embeddings_id = decoder_model.node_id("image_positional_embeddings")?;

    let iou_scores_id = decoder_model.node_id("iou_scores")?;
    let pred_masks_id = decoder_model.node_id("pred_masks")?;

    // Scale factors mapping original-image pixel coordinates into the
    // resized model-input space.
    let h_scale = input_h as f32 / image_h as f32;
    let w_scale = input_w as f32 / image_w as f32;

    let point_batch = 1;
    let nb_points_per_image = args.points.len();
    // Shape is (batch, point_batch, points_per_image, 2) with (x, y) pairs
    // in the last dimension.
    let input_points = NdTensor::from_fn(
        [1, point_batch, nb_points_per_image, 2],
        |[_, _, point, coord]| {
            if coord == 0 {
                args.points[point].0 as f32 * w_scale
            } else {
                args.points[point].1 as f32 * h_scale
            }
        },
    );

    // Point label values; all query points are marked as belonging to the
    // object of interest.
    const MATCH_POINT: i32 = 1;
    const _NON_MATCH_POINT: i32 = 0;
    const _BACKGROUND_POINT: i32 = -1;
    let input_labels = NdTensor::<i32, 3>::full([1, point_batch, nb_points_per_image], MATCH_POINT);

    // Run decoder and generate segmentation masks.
    let [_iou_scores, pred_masks] = decoder_model.run_n(
        vec![
            (input_points_id, input_points.into()),
            (input_labels_id, input_labels.into()),
            (decoder_embeddings_id, image_embeddings.into()),
            (decoder_pos_embeddings_id, image_pos_embeddings.into()),
        ],
        [iou_scores_id, pred_masks_id],
        None,
    )?;

    // Resize the output mask to match the original image and save to disk.
    let pred_masks: NdTensor<f32, 5> = pred_masks.try_into()?;
    let [_batch, _point_batch, _mask, mask_h, mask_w] = pred_masks.shape();
    // NOTE(review): this takes the first predicted mask; presumably masks
    // are ordered by quality — confirm against `iou_scores` if ranking
    // matters.
    let best_mask = pred_masks
        .slice::<2, _>((0, 0, 0))
        .reshaped([1, 1, mask_h, mask_w]);
    let resized_mask = best_mask.resize_image([image_h, image_w])?;
    write_image("segmented.png", resized_mask.slice::<3, _>(0).nd_view())?;

    Ok(())
}
Loading
Loading