Skip to content

Commit

Permalink
Add SAM2 and ONNX (#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
jamjamjon authored Aug 1, 2024
1 parent 451aa8c commit 46a4456
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 3 deletions.
14 changes: 14 additions & 0 deletions examples/sam/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,20 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.with_model("sam-vit-b-decoder-u8.onnx")?;
(options_encoder, options_decoder, "SAM")
}
SamKind::Sam2 => {
let options_encoder = Options::default()
// .with_model("sam2-hiera-tiny-encoder.onnx")?;
// .with_model("sam2-hiera-small-encoder.onnx")?;
.with_model("sam2-hiera-base-plus-encoder.onnx")?;
let options_decoder = Options::default()
.with_i31((1, 1, 1).into())
.with_i41((1, 1, 1).into())
.with_sam_kind(SamKind::Sam2)
// .with_model("sam2-hiera-tiny-decoder.onnx")?;
// .with_model("sam2-hiera-small-decoder.onnx")?;
.with_model("sam2-hiera-base-plus-decoder.onnx")?;
(options_encoder, options_decoder, "SAM2")
}
SamKind::MobileSam => {
let options_encoder = Options::default().with_model("mobile-sam-vit-t-encoder.onnx")?;

Expand Down
38 changes: 35 additions & 3 deletions src/models/sam.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::{DynConf, Mask, MinOptMax, Ops, Options, OrtEngine, Polygon, X, Y};
#[derive(Debug, Clone, clap::ValueEnum)]
pub enum SamKind {
Sam,
Sam2,
MobileSam,
SamHq,
EdgeSam,
Expand Down Expand Up @@ -94,7 +95,7 @@ impl SAM {
SamKind::Sam | SamKind::MobileSam | SamKind::SamHq => {
options_decoder.use_low_res_mask.unwrap_or(false)
}
SamKind::EdgeSam => true,
SamKind::EdgeSam | SamKind::Sam2 => true,
};

encoder.dry_run()?;
Expand Down Expand Up @@ -142,9 +143,13 @@ impl SAM {
xs0: &[DynamicImage],
prompts: &[SamPrompt],
) -> Result<Vec<Y>> {
let mut ys: Vec<Y> = Vec::new();
let (image_embeddings, high_res_features_0, high_res_features_1) = match self.kind {
SamKind::Sam2 => (&xs[0], Some(&xs[1]), Some(&xs[2])),
_ => (&xs[0], None, None),
};

for (idx, image_embedding) in xs[0].axis_iter(Axis(0)).enumerate() {
let mut ys: Vec<Y> = Vec::new();
for (idx, image_embedding) in image_embeddings.axis_iter(Axis(0)).enumerate() {
let image_width = xs0[idx].width() as f32;
let image_height = xs0[idx].height() as f32;
let ratio =
Expand Down Expand Up @@ -180,6 +185,32 @@ impl SAM {
prompts[idx].point_labels()?,
]
}
SamKind::Sam2 => {
vec![
X::from(image_embedding.into_dyn().into_owned()).insert_axis(0)?,
X::from(
high_res_features_0
.unwrap()
.slice(s![idx, .., .., ..])
.into_dyn()
.into_owned(),
)
.insert_axis(0)?,
X::from(
high_res_features_1
.unwrap()
.slice(s![idx, .., .., ..])
.into_dyn()
.into_owned(),
)
.insert_axis(0)?,
prompts[idx].point_coords(ratio)?,
prompts[idx].point_labels()?,
X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), // mask_input
X::zeros(&[1]), // has_mask_input
X::from(vec![image_height, image_width]), // orig_im_size
]
}
};

let ys_ = self.decoder.run(args)?;
Expand All @@ -196,6 +227,7 @@ impl SAM {
(&ys_[2], &ys_[1])
}
}
SamKind::Sam2 => (&ys_[0], &ys_[1]),
SamKind::EdgeSam => match (ys_[0].ndim(), ys_[1].ndim()) {
(2, 4) => (&ys_[1], &ys_[0]),
(4, 2) => (&ys_[0], &ys_[1]),
Expand Down

0 comments on commit 46a4456

Please sign in to comment.