Commit

Add Xs
jamjamjon committed Aug 3, 2024
1 parent a78b8d7 commit 10990c2
Showing 17 changed files with 92 additions and 175 deletions.
7 changes: 4 additions & 3 deletions src/core/ort_engine.rs
@@ -288,6 +288,7 @@ impl OrtEngine {
let x: Array<f32, IxDyn> = Array::ones(x).into_dyn();
xs.push(X::from(x));
}
+ let xs = Xs::from(xs);
for _ in 0..self.num_dry_run {
// self.run(xs.as_ref())?;
self.run(xs.clone())?;
@@ -298,11 +299,11 @@ impl OrtEngine {
Ok(())
}

- pub fn run(&mut self, xs: Vec<X>) -> Result<Xs> {
+ pub fn run(&mut self, xs: Xs) -> Result<Xs> {
// inputs dtype alignment
let mut xs_ = Vec::new();
let t_pre = std::time::Instant::now();
- for (idtype, x) in self.inputs_attrs.dtypes.iter().zip(xs.iter()) {
+ for (idtype, x) in self.inputs_attrs.dtypes.iter().zip(xs.into_iter()) {
let x_ = match &idtype {
TensorElementType::Float32 => ort::Value::from_array(x.view())?.into_dyn(),
TensorElementType::Float16 => {
@@ -358,7 +359,7 @@ impl OrtEngine {
.into_owned(),
_ => todo!(),
};
- ys.add(name.as_str(), X::from(y_))?;
+ ys.push_kv(name.as_str(), X::from(y_))?;
}
let t_post = t_post.elapsed();
self.ts.add_or_push(2, t_post);
6 changes: 3 additions & 3 deletions src/core/vision.rs
@@ -1,4 +1,4 @@
- use crate::{Options, Xs, X, Y};
+ use crate::{Options, Xs, Y};

pub trait Vision: Sized {
type Input; // DynamicImage
@@ -7,10 +7,10 @@ pub trait Vision: Sized {
fn new(options: Options) -> anyhow::Result<Self>;

/// Preprocesses the input data.
- fn preprocess(&self, xs: &[Self::Input]) -> anyhow::Result<Vec<X>>;
+ fn preprocess(&self, xs: &[Self::Input]) -> anyhow::Result<Xs>;

/// Executes the model on the preprocessed data.
- fn inference(&mut self, xs: Vec<X>) -> anyhow::Result<Xs>;
+ fn inference(&mut self, xs: Xs) -> anyhow::Result<Xs>;

/// Postprocesses the model's output.
fn postprocess(&self, xs: Xs, xs0: &[Self::Input]) -> anyhow::Result<Vec<Y>>;
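With `preprocess` now returning `Xs` and `inference` consuming it, the three `Vision` methods chain directly. A minimal sketch of that composition, assuming the crate root (here `usls`) re-exports `Vision` and `Y`; the `forward` helper itself is illustrative and not part of this commit:

```rust
use usls::{Vision, Y};

// Illustrative helper: drives the full pipeline for any `Vision` implementor.
// `preprocess` now hands an `Xs` straight to `inference`; the intermediate Vec<X> is gone.
fn forward<M: Vision>(model: &mut M, images: &[M::Input]) -> anyhow::Result<Vec<Y>> {
    let xs = model.preprocess(images)?; // -> Xs (previously Vec<X>)
    let ys = model.inference(xs)?;      // Xs -> Xs
    model.postprocess(ys, images)       // -> Vec<Y>
}
```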
60 changes: 53 additions & 7 deletions src/core/xs.rs
@@ -2,12 +2,30 @@ use anyhow::Result;
use std::collections::HashMap;
use std::ops::{Deref, Index};

- use crate::X;
+ use crate::{string_random, X};

#[derive(Debug, Default, Clone)]
pub struct Xs {
- names: Vec<String>,
map: HashMap<String, X>,
+ names: Vec<String>,
}

+ impl From<X> for Xs {
+ fn from(x: X) -> Self {
+ let mut xs = Self::default();
+ xs.push(x);
+ xs
+ }
+ }
+
+ impl From<Vec<X>> for Xs {
+ fn from(xs: Vec<X>) -> Self {
+ let mut ys = Self::default();
+ for x in xs {
+ ys.push(x);
+ }
+ ys
+ }
+ }
+
impl Xs {
@@ -17,12 +35,25 @@ impl Xs {
}
}

- pub fn add(&mut self, key: &str, value: X) -> Result<()> {
+ pub fn push(&mut self, value: X) {
+ loop {
+ let key = string_random(5);
+ if !self.map.contains_key(&key) {
+ self.names.push(key.to_string());
+ self.map.insert(key.to_string(), value);
+ break;
+ }
+ }
+ }
+
+ pub fn push_kv(&mut self, key: &str, value: X) -> Result<()> {
if !self.map.contains_key(key) {
self.names.push(key.to_string());
+ self.map.insert(key.to_string(), value);
+ Ok(())
+ } else {
+ anyhow::bail!("Xs already contains key: {:?}", key)
}
- self.map.insert(key.to_string(), value);
- Ok(())
}

pub fn names(&self) -> &Vec<String> {
@@ -57,11 +88,26 @@ impl Index<usize> for Xs {
}
}

+ pub struct XsIter<'a> {
+ inner: std::vec::IntoIter<&'a X>,
+ }
+
+ impl<'a> Iterator for XsIter<'a> {
+ type Item = &'a X;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ self.inner.next()
+ }
+ }
+
impl<'a> IntoIterator for &'a Xs {
type Item = &'a X;
- type IntoIter = std::collections::hash_map::Values<'a, String, X>;
+ type IntoIter = XsIter<'a>;

fn into_iter(self) -> Self::IntoIter {
- self.map.values()
+ let values: Vec<&X> = self.names.iter().map(|x| &self.map[x]).collect();
+ XsIter {
+ inner: values.into_iter(),
+ }
}
}
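Taken together, `Xs` is now an ordered, name-addressable collection of `X` tensors: `push` stores a value under a randomly generated key, `push_kv` stores it under an explicit key and rejects duplicates, and iteration follows insertion order via `XsIter`. A minimal usage sketch, assuming the crate root (`usls`) re-exports `Xs` and `X`, that `X: From<Array<f32, IxDyn>>` as used in `ort_engine.rs` above, and with `"images"` as a purely illustrative key:

```rust
use anyhow::Result;
use ndarray::{Array, IxDyn};
use usls::{Xs, X};

fn main() -> Result<()> {
    let a: Array<f32, IxDyn> = Array::ones((1, 3, 224, 224)).into_dyn();
    let b: Array<f32, IxDyn> = Array::zeros((1, 77)).into_dyn();

    // Anonymous entry: From<X> calls `push`, which generates a random key.
    let mut xs = Xs::from(X::from(a));
    // Named entry: `push_kv` errors if the key is already present.
    xs.push_kv("images", X::from(b))?;

    assert_eq!(xs.names().len(), 2);
    let _first = &xs[0]; // Index<usize> access
    for _x in &xs {
        // yields `&X` in insertion order via XsIter
    }
    Ok(())
}
```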
10 changes: 6 additions & 4 deletions src/models/blip.rs
@@ -4,7 +4,9 @@ use ndarray::{s, Array, Axis, IxDyn};
use std::io::Write;
use tokenizers::Tokenizer;

- use crate::{Embedding, LogitsSampler, MinOptMax, Ops, Options, OrtEngine, TokenizerStream, X, Y};
+ use crate::{
+ Embedding, LogitsSampler, MinOptMax, Ops, Options, OrtEngine, TokenizerStream, Xs, X, Y,
+ };

#[derive(Debug)]
pub struct Blip {
@@ -58,7 +60,7 @@ impl Blip {
),
Ops::Nhwc2nchw,
])?;
- let ys = self.visual.run(vec![xs_])?;
+ let ys = self.visual.run(Xs::from(xs_))?;
Ok(Y::default().with_embedding(&Embedding::from(ys[0].to_owned())))
}

@@ -108,12 +110,12 @@ impl Blip {
Array::ones(input_ids_nd.shape()).into_dyn();
let input_ids_attn_mask = X::from(input_ids_attn_mask);

- let y = self.textual.run(vec![
+ let y = self.textual.run(Xs::from(vec![
input_ids_nd,
input_ids_attn_mask,
X::from(image_embeds.data().to_owned()),
X::from(image_embeds_attn_mask.to_owned()),
- ])?; // N, length, vocab_size
+ ]))?; // N, length, vocab_size
let y = y[0].slice(s!(0, -1.., ..));
let logits = y.slice(s!(0, ..)).to_vec();
let token_id = logits_sampler.decode(&logits)?;
6 changes: 3 additions & 3 deletions src/models/clip.rs
@@ -3,7 +3,7 @@ use image::DynamicImage;
use ndarray::Array2;
use tokenizers::{PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer};

- use crate::{Embedding, MinOptMax, Ops, Options, OrtEngine, X, Y};
+ use crate::{Embedding, MinOptMax, Ops, Options, OrtEngine, Xs, X, Y};

#[derive(Debug)]
pub struct Clip {
@@ -69,7 +69,7 @@ impl Clip {
),
Ops::Nhwc2nchw,
])?;
- let ys = self.visual.run(vec![xs_])?;
+ let ys = self.visual.run(Xs::from(xs_))?;
Ok(Y::default().with_embedding(&Embedding::from(ys[0].to_owned())))
}

@@ -84,7 +84,7 @@ impl Clip {
.collect();
let xs = Array2::from_shape_vec((texts.len(), self.context_length), xs)?.into_dyn();
let xs = X::from(xs);
- let ys = self.textual.run(vec![xs])?;
+ let ys = self.textual.run(Xs::from(xs))?;
Ok(Y::default().with_embedding(&Embedding::from(ys[0].to_owned())))
}

2 changes: 1 addition & 1 deletion src/models/db.rs
@@ -60,7 +60,7 @@ impl DB {
Ops::Standardize(&[0.485, 0.456, 0.406], &[0.229, 0.224, 0.225], 3),
Ops::Nhwc2nchw,
])?;
- let ys = self.engine.run(vec![xs_])?;
+ let ys = self.engine.run(Xs::from(xs_))?;
self.postprocess(ys, xs)
}

2 changes: 1 addition & 1 deletion src/models/depth_anything.rs
@@ -41,7 +41,7 @@ impl DepthAnything {
Ops::Standardize(&[0.485, 0.456, 0.406], &[0.229, 0.224, 0.225], 3),
Ops::Nhwc2nchw,
])?;
- let ys = self.engine.run(vec![xs_])?;
+ let ys = self.engine.run(Xs::from(xs_))?;
self.postprocess(ys, xs)
}

4 changes: 2 additions & 2 deletions src/models/dinov2.rs
@@ -1,4 +1,4 @@
- use crate::{Embedding, MinOptMax, Ops, Options, OrtEngine, X, Y};
+ use crate::{Embedding, MinOptMax, Ops, Options, OrtEngine, Xs, X, Y};
use anyhow::Result;
use image::DynamicImage;
// use std::path::PathBuf;
@@ -63,7 +63,7 @@ impl Dinov2 {
),
Ops::Nhwc2nchw,
])?;
- let ys = self.engine.run(vec![xs_])?;
+ let ys = self.engine.run(Xs::from(xs_))?;
Ok(Y::default().with_embedding(&Embedding::from(ys[0].to_owned())))
}

2 changes: 0 additions & 2 deletions src/models/mod.rs
@@ -6,7 +6,6 @@ mod db;
mod depth_anything;
mod dinov2;
mod modnet;
- mod rtdetr;
mod rtmo;
mod sam;
mod svtr;
@@ -20,7 +19,6 @@ pub use db::DB;
pub use depth_anything::DepthAnything;
pub use dinov2::Dinov2;
pub use modnet::MODNet;
- pub use rtdetr::RTDETR;
pub use rtmo::RTMO;
pub use sam::{SamKind, SamPrompt, SAM};
pub use svtr::SVTR;
2 changes: 1 addition & 1 deletion src/models/modnet.rs
@@ -42,7 +42,7 @@ impl MODNet {
Ops::Nhwc2nchw,
])?;

- let ys = self.engine.run(vec![xs_])?;
+ let ys = self.engine.run(Xs::from(xs_))?;
self.postprocess(ys, xs)
}

(7 more changed files not shown)
