Update embedding_model module and Cargo.toml #7

Merged · 1 commit · Apr 14, 2024
153 changes: 83 additions & 70 deletions Cargo.lock

Large diffs are not rendered by default.

12 changes: 8 additions & 4 deletions Cargo.toml
@@ -3,12 +3,12 @@ name = "embed_anything"
version = "0.1.5"
edition = "2021"


[lib]
name = "embed_anything"
crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.20", features = ["extension-module"] }
pyo3-asyncio = { version = "0.20", features = ["tokio-runtime"] }
tokio = "1.9"

serde_json = "1.0.112"
reqwest = { version = "0.12.2", features = ["json"] }
futures = "0.3.30"
@@ -25,4 +25,8 @@ candle-transformers = { git = "https://github.com/huggingface/candle.git", versi
candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.5.0", features = ["mkl"] }
anyhow = "1.0.81"
intel-mkl-src = "0.8.1"
candle-pyo3 = { git = "https://github.com/huggingface/candle.git", version = "0.5.0" }
tokio = {version = "1.37.0", features=["rt-multi-thread"]}
pyo3 = { version = "0.21" }


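The dependency update appears to drop the pyo3 0.20 / pyo3-asyncio pair in favour of pyo3 0.21 and tokio 1.37 with the `rt-multi-thread` feature, which gates the multi-thread runtime builder that lib.rs now calls. A minimal standalone sketch of that builder (the worker-thread count here is purely illustrative):

```rust
use tokio::runtime::Builder;

fn main() {
    // `new_multi_thread` is only compiled in when the `rt-multi-thread`
    // feature is enabled; without it this line fails to build.
    let runtime = Builder::new_multi_thread()
        .worker_threads(2) // illustrative; defaults to the number of CPU cores
        .build()
        .expect("failed to build tokio runtime");

    // Drive an async block to completion from synchronous code.
    let sum = runtime.block_on(async { 1 + 2 });
    assert_eq!(sum, 3);
}
```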
Binary file removed embed_anything-0.1.1.tar.gz
Binary file not shown.
29 changes: 16 additions & 13 deletions src/embedding_model/bert.rs
@@ -1,3 +1,4 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use anyhow::Error as E;
@@ -12,25 +13,26 @@ pub struct BertEmbeder {
pub model: BertModel,
pub tokenizer: Tokenizer,
}
impl BertEmbeder {
pub fn default() -> anyhow::Result<Self> {

impl Default for BertEmbeder {
fn default() -> Self {
let device = Device::Cpu;
let default_model = "sentence-transformers/all-MiniLM-L12-v2".to_string();
let default_revision = "refs/pr/21".to_string();
let (model_id, _revision) = (default_model, default_revision);
let repo = Repo::model(model_id);
let (config_filename, tokenizer_filename, weights_filename) = {
let api = Api::new()?;
let api = Api::new().unwrap();
let api = api.repo(repo);
let config = api.get("config.json")?;
let tokenizer = api.get("tokenizer.json")?;
let weights = api.get("model.safetensors")?;
let config = api.get("config.json").unwrap();
let tokenizer = api.get("tokenizer.json").unwrap();
let weights = api.get("model.safetensors").unwrap();

(config, tokenizer, weights)
};
let config = std::fs::read_to_string(config_filename)?;
let mut config: Config = serde_json::from_str(&config)?;
let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let config = std::fs::read_to_string(config_filename).unwrap();
let mut config: Config = serde_json::from_str(&config).unwrap();
let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg).unwrap();

let pp = PaddingParams {
strategy: tokenizers::PaddingStrategy::BatchLongest,
@@ -39,14 +41,15 @@ impl BertEmbeder {
tokenizer.with_padding(Some(pp));

let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? };
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device).unwrap() };

config.hidden_act = HiddenAct::GeluApproximate;

let model = BertModel::load(vb, &config)?;
Ok(BertEmbeder { model, tokenizer })
let model = BertModel::load(vb, &config).unwrap();
BertEmbeder { model, tokenizer }
}

}
impl BertEmbeder {
pub fn tokenize_batch(&self, text_batch: &[String], device: &Device) -> anyhow::Result<Tensor> {
let tokens = self
.tokenizer
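bert.rs moves construction from a fallible `default() -> anyhow::Result<Self>` into `impl Default`, swapping every `?` for `.unwrap()`; a failed hub download or config parse now panics instead of returning an error. A minimal sketch, with a stand-in struct and a hypothetical `try_default` name (not an API in this crate), of a pattern that keeps both behaviours:

```rust
// Stand-in for an embedder that loads remote weights on construction.
struct Embeder;

impl Embeder {
    // Fallible constructor: downloads and parsing would use `?` here,
    // so callers (e.g. the PyO3 entry points) can turn a failure into
    // a Python exception instead of a process abort.
    fn try_default() -> anyhow::Result<Self> {
        Ok(Embeder)
    }
}

impl Default for Embeder {
    fn default() -> Self {
        Self::try_default().expect("failed to initialise default embeder")
    }
}
```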
23 changes: 11 additions & 12 deletions src/embedding_model/clip.rs
@@ -1,7 +1,6 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use std::any;

use anyhow::Error as E;


@@ -18,23 +17,23 @@ pub struct ClipEmbeder {
pub tokenizer: Tokenizer,
}

impl ClipEmbeder {
pub fn default() -> anyhow::Result<Self> {
let api = hf_hub::api::sync::Api::new()?;
impl Default for ClipEmbeder {
fn default() -> Self {
let api = hf_hub::api::sync::Api::new().unwrap();
let api = api.repo(hf_hub::Repo::with_revision(
"openai/clip-vit-base-patch32".to_string(),
hf_hub::RepoType::Model,
"refs/pr/15".to_string(),
));
let model_file = api.get("model.safetensors")?;
let model_file = api.get("model.safetensors").unwrap();
let config = clip::ClipConfig::vit_base_patch32();
let device = Device::Cpu;
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)?
VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device).unwrap()
};
let model = clip::ClipModel::new(vb, &config)?;
let tokenizer = Self::get_tokenizer(None)?;
Ok(ClipEmbeder { model, tokenizer })
let model = clip::ClipModel::new(vb, &config).unwrap();
let tokenizer = Self::get_tokenizer(None).unwrap();
ClipEmbeder { model, tokenizer }
}
}

@@ -123,7 +122,7 @@ impl ClipEmbeder {

fn load_images<T: AsRef<std::path::Path>>(
&self,
paths: &Vec<T>,
paths: &[T],
image_size: usize,
) -> anyhow::Result<Tensor> {
let mut images = vec![];
@@ -144,7 +143,7 @@ }
}

impl EmbedImage for ClipEmbeder{
fn embed_image_batch<T: AsRef<std::path::Path>>(&self, image_paths:&Vec<T>) -> anyhow::Result<Vec<EmbedData>> {
fn embed_image_batch<T: AsRef<std::path::Path>>(&self, image_paths:&[T]) -> anyhow::Result<Vec<EmbedData>> {
let config = clip::ClipConfig::vit_base_patch32();

let images = self.load_images(image_paths, config.image_size).unwrap();
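Besides the same `Default` migration, clip.rs loosens `&Vec<T>` parameters to `&[T]`, the idiomatic slice form: callers can pass a `Vec`, an array, or any other contiguous sequence via deref coercion. A small sketch with a hypothetical function of the same shape as `load_images`:

```rust
use std::path::Path;

// Hypothetical function mirroring the `paths: &[T]` parameter shape.
fn count_paths<T: AsRef<Path>>(paths: &[T]) -> usize {
    paths.len()
}

fn main() {
    let from_vec = vec!["a.png".to_string(), "b.png".to_string()];
    let from_array = ["a.png", "b.png"];

    // Both coerce to a slice; with `&Vec<T>` the array version would
    // first have to be copied into a Vec.
    assert_eq!(count_paths(&from_vec), 2);
    assert_eq!(count_paths(&from_array), 2);
}
```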
3 changes: 1 addition & 2 deletions src/embedding_model/embed.rs
@@ -1,6 +1,5 @@
use pyo3::prelude::*;
use serde::Deserialize;
use std::any;
use std::collections::HashMap;


@@ -65,5 +64,5 @@ pub trait Embed {
}

pub trait EmbedImage {
fn embed_image_batch<T: AsRef<std::path::Path>>(&self, image_paths:&Vec<T>) -> anyhow::Result<Vec<EmbedData>>;
fn embed_image_batch<T: AsRef<std::path::Path>>(&self, image_paths:&[T]) -> anyhow::Result<Vec<EmbedData>>;
}
24 changes: 13 additions & 11 deletions src/embedding_model/jina.rs
@@ -1,9 +1,6 @@
use anyhow::Error as E;
use candle_core::{DType, Device, Tensor};
use candle_nn::{Module, VarBuilder};
// use rust_bert::pipelines::sentence_embeddings::{
// SentenceEmbeddingsBuilder, SentenceEmbeddingsModel, SentenceEmbeddingsModelType,
// };
use super::embed::{Embed, EmbedData};
use candle_transformers::models::jina_bert::{BertModel, Config};
use hf_hub::{Repo, RepoType};
@@ -12,30 +9,35 @@ pub struct JinaEmbeder {
pub model: BertModel,
pub tokenizer: Tokenizer,
}
impl JinaEmbeder {
pub fn default() -> anyhow::Result<Self> {
let api = hf_hub::api::sync::Api::new()?;

impl Default for JinaEmbeder {
fn default() -> Self {
let api = hf_hub::api::sync::Api::new().unwrap();
let model_file = api
.repo(Repo::new(
"jinaai/jina-embeddings-v2-base-en".to_string(),
RepoType::Model,
))
.get("model.safetensors")?;
.get("model.safetensors")
.unwrap();
let config = Config::v2_base();

let device = Device::Cpu;
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)?
VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device).unwrap()
};
let model = BertModel::new(vb, &config)?;
let mut tokenizer = Self::get_tokenizer(None)?;
let model = BertModel::new(vb, &config).unwrap();
let mut tokenizer = Self::get_tokenizer(None).unwrap();
let pp = tokenizers::PaddingParams {
strategy: tokenizers::PaddingStrategy::BatchLongest,
..Default::default()
};
tokenizer.with_padding(Some(pp));
Ok(JinaEmbeder { model, tokenizer })
JinaEmbeder { model, tokenizer }
}
}

impl JinaEmbeder {

pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
let tokenizer = match tokenizer {
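jina.rs follows the same pattern, and its constructor installs `PaddingStrategy::BatchLongest` before any encoding happens: every sequence in a batch is padded to the batch's longest member so the token ids can be stacked into one rectangular tensor. A minimal sketch of that effect, assuming a `tokenizer.json` already on disk (the crate itself fetches one from the hub):

```rust
use tokenizers::Tokenizer;

fn main() -> anyhow::Result<()> {
    let mut tokenizer = Tokenizer::from_file("tokenizer.json").map_err(anyhow::Error::msg)?;
    tokenizer.with_padding(Some(tokenizers::PaddingParams {
        strategy: tokenizers::PaddingStrategy::BatchLongest,
        ..Default::default()
    }));

    let encodings = tokenizer
        .encode_batch(vec!["short text", "a somewhat longer piece of text"], true)
        .map_err(anyhow::Error::msg)?;

    // Both rows now have identical length, ready to stack into a tensor.
    assert_eq!(encodings[0].get_ids().len(), encodings[1].get_ids().len());
    Ok(())
}
```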
43 changes: 21 additions & 22 deletions src/lib.rs
@@ -1,37 +1,36 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

pub mod embedding_model;
pub mod file_embed;
pub mod parser;
pub mod pdf_processor;

use std::path::PathBuf;

use embedding_model::{
clip::ClipEmbeder,
embed::{Embed, EmbedData, EmbedImage, Embeder},
};
use embedding_model::embed::{EmbedData, EmbedImage, Embeder};
use file_embed::FileEmbeder;
use parser::FileParser;
use pyo3::{exceptions::PyValueError, prelude::*};
use rayon::prelude::*;
use tokio::runtime::Builder;

#[pyfunction]
pub fn embed_query(query: Vec<String>, embeder: &str) -> PyResult<Vec<EmbedData>> {
let embedding_model = match embeder {
"OpenAI" => Embeder::OpenAI(embedding_model::openai::OpenAIEmbeder::default()),
"Jina" => Embeder::Jina(embedding_model::jina::JinaEmbeder::default().unwrap()),
"Clip" => Embeder::Clip(embedding_model::clip::ClipEmbeder::default().unwrap()),
"Bert" => Embeder::Bert(embedding_model::bert::BertEmbeder::default().unwrap()),
"Jina" => Embeder::Jina(embedding_model::jina::JinaEmbeder::default()),
"Clip" => Embeder::Clip(embedding_model::clip::ClipEmbeder::default()),
"Bert" => Embeder::Bert(embedding_model::bert::BertEmbeder::default()),
_ => {
return Err(PyValueError::new_err(
"Invalid embedding model. Choose between OpenAI and AllMiniLmL12V2.",
))
}
};
let runtime = Builder::new_multi_thread().enable_all().build().unwrap();

let embeddings = tokio::runtime::Runtime::new()
.unwrap()
.block_on(embedding_model.embed(&query))
.unwrap();
let embeddings = runtime.block_on(embedding_model.embed(&query)).unwrap();
Ok(embeddings)
}

@@ -40,9 +39,9 @@ pub fn embed_query(query: Vec<String>, embeder: &str) -> PyResult<Vec<EmbedData>
pub fn embed_file(file_name: &str, embeder: &str) -> PyResult<Vec<EmbedData>> {
let embedding_model = match embeder {
"OpenAI" => Embeder::OpenAI(embedding_model::openai::OpenAIEmbeder::default()),
"Jina" => Embeder::Jina(embedding_model::jina::JinaEmbeder::default().unwrap()),
"Clip" => Embeder::Clip(embedding_model::clip::ClipEmbeder::default().unwrap()),
"Bert" => Embeder::Bert(embedding_model::bert::BertEmbeder::default().unwrap()),
"Jina" => Embeder::Jina(embedding_model::jina::JinaEmbeder::default()),
"Clip" => Embeder::Clip(embedding_model::clip::ClipEmbeder::default()),
"Bert" => Embeder::Bert(embedding_model::bert::BertEmbeder::default()),
_ => {
return Err(PyValueError::new_err(
"Invalid embedding model. Choose between OpenAI and AllMiniLmL12V2.",
@@ -53,8 +52,8 @@ pub fn embed_file(file_name: &str, embeder: &str) -> PyResult<Vec<EmbedData>> {
let mut file_embeder = FileEmbeder::new(file_name.to_string());
let text = file_embeder.extract_text().unwrap();
file_embeder.split_into_chunks(&text, 100);
tokio::runtime::Runtime::new()
.unwrap()
let runtime = Builder::new_multi_thread().enable_all().build().unwrap();
runtime
.block_on(file_embeder.embed(&embedding_model))
.unwrap();
Ok(file_embeder.embeddings)
@@ -70,17 +69,17 @@ pub fn embed_directory(directory: PathBuf, embeder: &str) -> PyResult<Vec<EmbedD
.unwrap(),
"Jina" => emb(
directory,
Embeder::Jina(embedding_model::jina::JinaEmbeder::default().unwrap()),
Embeder::Jina(embedding_model::jina::JinaEmbeder::default()),
)
.unwrap(),
"Bert" => emb(
directory,
Embeder::Bert(embedding_model::bert::BertEmbeder::default().unwrap()),
Embeder::Bert(embedding_model::bert::BertEmbeder::default()),
)
.unwrap(),
"Clip" => emb_image(
directory,
embedding_model::clip::ClipEmbeder::default().unwrap(),
embedding_model::clip::ClipEmbeder::default(),
)
.unwrap(),

@@ -96,7 +95,7 @@ pub fn embed_directory(directory: PathBuf, embeder: &str) -> PyResult<Vec<EmbedD

/// A Python module implemented in Rust.
#[pymodule]
fn embed_anything(_py: Python, m: &PyModule) -> PyResult<()> {
fn embed_anything(m: &Bound<'_,PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(embed_file, m)?)?;
m.add_function(wrap_pyfunction!(embed_directory, m)?)?;
m.add_function(wrap_pyfunction!(embed_query, m)?)?;
@@ -115,8 +114,8 @@ fn emb(directory: PathBuf, embedding_model: Embeder) -> PyResult<Vec<EmbedData>>
let mut file_embeder = FileEmbeder::new(file.to_string());
let text = file_embeder.extract_text().unwrap();
file_embeder.split_into_chunks(&text, 100);
tokio::runtime::Runtime::new()
.unwrap()
let runtime = Builder::new_multi_thread().enable_all().build().unwrap();
runtime
.block_on(file_embeder.embed(&embedding_model))
.unwrap();
file_embeder.embeddings
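Two changes run through lib.rs: the `#[pymodule]` entry point adopts the pyo3 0.21 `Bound<'_, PyModule>` API, and each `#[pyfunction]` now builds its runtime with `Builder::new_multi_thread().enable_all()` rather than `tokio::runtime::Runtime::new()`. A runtime is still constructed on every call, though; a possible follow-up, not part of this PR and assuming `std::sync::OnceLock`, would share one:

```rust
use std::sync::OnceLock;
use tokio::runtime::{Builder, Runtime};

// Lazily build a single multi-thread runtime and reuse it across
// embed_query / embed_file / embed_directory invocations.
fn runtime() -> &'static Runtime {
    static RT: OnceLock<Runtime> = OnceLock::new();
    RT.get_or_init(|| {
        Builder::new_multi_thread()
            .enable_all()
            .build()
            .expect("failed to build tokio runtime")
    })
}
```

Call sites would then read `runtime().block_on(...)` with no per-call construction cost.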
13 changes: 7 additions & 6 deletions test.py
@@ -1,21 +1,20 @@
import os
import time
import numpy as np
os.add_dll_directory(r'D:\libtorch\lib')
# os.add_dll_directory(r'D:\test')
from embed_anything import EmbedData
import embed_anything
from PIL import Image
# start = time.time()
import time

# data:list[EmbedData] = embed_anything.embed_file("test_files/TUe_SOP_AI_2.pdf", embeder= "Bert")

# embeddings = np.array([data.embedding for data in data])

# end = time.time()

# print(embeddings)
# print("Time taken: ", end-start)


start = time.time()
data:list[EmbedData] = embed_anything.embed_directory("test_files", embeder= "Clip")

embeddings = np.array([data.embedding for data in data])
@@ -29,5 +28,7 @@

max_index = np.argmax(similarities)

Image.open(data[max_index].text).show()
# Image.open(data[max_index].text).show()
print(data[max_index].text)
end = time.time()
print("Time taken: ", end-start)
Binary file removed test_files/TUe_SOP_AI_2.pdf
Binary file not shown.
Binary file removed test_files/wa4_ethics_written_assignment.pdf
Binary file not shown.