Skip to content

Commit

Permalink
Merge pull request #13 from akshayballal95/main
Browse files Browse the repository at this point in the history
Remove Async
  • Loading branch information
akshayballal95 authored May 4, 2024
2 parents 78a5b31 + b65b979 commit 335a948
Show file tree
Hide file tree
Showing 13 changed files with 182 additions and 154 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

42 changes: 37 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,59 @@ version = "0.1.14"
edition = "2021"

[dependencies]
# Data Serialization
serde = { version = "1.0.196", features = ["derive"] }
serde_json = "1.0.112"

# HTTP Client
reqwest = { version = "0.12.2", features = ["json", "blocking"] }
serde = {version = "1.0.196", features = ["derive"]}
pdf-extract = "0.7.4"

# Filesystem
walkdir = "2.4.0"

# Regular Expressions
regex = "1.10.3"

# Parallelism
rayon = "1.8.1"

# Image Processing
image = "0.25.1"
hf-hub = "0.3.2"

# Natural Language Processing
tokenizers = "0.15.2"

# PDF Processing
pdf-extract = "0.7.4"

# Hugging Face Libraries
hf-hub = "0.3.2"
candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.5.0" }
candle-transformers = { git = "https://github.com/huggingface/candle.git", version = "0.5.0" }
candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.5.0" }

# Error Handling
anyhow = "1.0.81"
tokio = {version = "1.37.0", features=["rt-multi-thread", "macros"]}

# Asynchronous Programming
tokio = { version = "1.37.0", features = ["macros", "rt-multi-thread"] }

# Python Interoperability
pyo3 = { version = "0.21" }
intel-mkl-src = {version = "0.8.1", optional = true }

# Optional Dependency
intel-mkl-src = { version = "0.8.1", optional = true }

# Markdown Processing
markdown-parser = "0.1.2"
markdown_to_text = "1.0.0"

# Web Scraping
scraper = "0.19.0"

# Text Processing
text-cleaner = "0.1.0"
url = "2.5.0"

[dev-dependencies]
tempdir = "0.3.7"
Expand Down
2 changes: 1 addition & 1 deletion examples/bert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{path::PathBuf, time::Instant};

fn main() {
let now = Instant::now();
let out = embed_directory(PathBuf::from("test_files"), "Bert", Some(vec!["md".to_string()])).unwrap();
let out = embed_directory(PathBuf::from("test_files"), "Bert", Some(vec!["pdf".to_string()])).unwrap();
println!("{:?}", out);
let elapsed_time = now.elapsed();
println!("Elapsed Time: {}", elapsed_time.as_secs_f32());
Expand Down
13 changes: 8 additions & 5 deletions examples/web_embed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@ use candle_core::Tensor;

#[tokio::main]
async fn main() {
let url = "https://en.wikipedia.org/wiki/Long_short-term_memory";

let website_processor = website_processor::WebsiteProcesor;
let start_time = std::time::Instant::now();
let url = "https://www.scrapingbee.com/blog/web-scraping-rust/";

let website_processor = website_processor::WebsiteProcessor;
let webpage = website_processor.process_website(url).await.unwrap();
let embeder = embed_anything::embedding_model::bert::BertEmbeder::default();
let embed_data = webpage.embed_webpage(&embeder).await.unwrap();
let embed_data = webpage.embed_webpage(&embeder).unwrap();
let embeddings: Vec<Vec<f32>> = embed_data.iter().map(|data| data.embedding.clone()).collect();

let embeddings = Tensor::from_vec(
Expand All @@ -17,8 +19,8 @@ async fn main() {
&candle_core::Device::Cpu,
).unwrap();

let query = vec!["how to use lstm for nlp".to_string()];
let query_embedding: Vec<f32> = embeder.embed(&query, None).await.unwrap().iter().map(|data| data.embedding.clone()).flatten().collect();
let query = vec!["Rust for web scraping".to_string()];
let query_embedding: Vec<f32> = embeder.embed(&query, None).unwrap().iter().map(|data| data.embedding.clone()).flatten().collect();

let query_embedding_tensor = Tensor::from_vec(
query_embedding.clone(),
Expand All @@ -40,5 +42,6 @@ async fn main() {
let data = &embed_data[max_similarity_index];

println!("{:?}", data);
println!("Time taken: {:?}", start_time.elapsed());

}
6 changes: 3 additions & 3 deletions src/embedding_model/bert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ impl BertEmbeder {
Ok(Tensor::stack(&token_ids, 0)?)
}

pub async fn embed(&self, text_batch: &[String],metadata:Option<HashMap<String,String>>) -> Result<Vec<EmbedData>, reqwest::Error> {
pub fn embed(&self, text_batch: &[String],metadata:Option<HashMap<String,String>>) -> Result<Vec<EmbedData>, anyhow::Error> {
let token_ids = self.tokenize_batch(text_batch, &self.model.device).unwrap();
let token_type_ids = token_ids.zeros_like().unwrap();
let embeddings = self.model.forward(&token_ids, &token_type_ids).unwrap();
Expand All @@ -89,7 +89,7 @@ impl Embed for BertEmbeder {
fn embed(
&self,
text_batch: &[String],metadata: Option<HashMap<String,String>>
) -> impl std::future::Future<Output = Result<Vec<EmbedData>, reqwest::Error>> {
) -> Result<Vec<EmbedData>, anyhow::Error> {
self.embed(text_batch, metadata)
}
}
Expand All @@ -99,7 +99,7 @@ impl TextEmbed for BertEmbeder {
&self,
text_batch: &[String],
metadata: Option<HashMap<String,String>>
) -> impl std::future::Future<Output = Result<Vec<EmbedData>, reqwest::Error>> {
) -> Result<Vec<EmbedData>, anyhow::Error> {
self.embed(text_batch, metadata)
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/embedding_model/clip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ impl EmbedImage for ClipEmbeder {
}

impl Embed for ClipEmbeder {
async fn embed(&self, text_batch: &[String], metadata: Option<HashMap<String, String>>) -> Result<Vec<EmbedData>, reqwest::Error> {
fn embed(&self, text_batch: &[String], metadata: Option<HashMap<String, String>>) -> Result<Vec<EmbedData>, anyhow::Error> {
let (input_ids, _vec_seq) = ClipEmbeder::tokenize_sequences(
Some(text_batch.to_vec()),
&self.tokenizer,
Expand Down
14 changes: 7 additions & 7 deletions src/embedding_model/embed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,12 @@ pub enum Embeder {
}

impl Embeder {
pub async fn embed(&self, text_batch: &[String], metadata: Option<HashMap<String, String>>) -> Result<Vec<EmbedData>, reqwest::Error> {
pub fn embed(&self, text_batch: &[String], metadata: Option<HashMap<String, String>>) -> Result<Vec<EmbedData>, anyhow::Error> {
match self {
Embeder::OpenAI(embeder) => TextEmbed::embed(embeder, text_batch, metadata).await,
Embeder::Jina(embeder) => TextEmbed::embed(embeder, text_batch, metadata).await,
Embeder::Clip(embeder) => Embed::embed(embeder, text_batch, metadata).await,
Embeder::Bert(embeder) => TextEmbed::embed(embeder, text_batch, metadata).await,
Embeder::OpenAI(embeder) => TextEmbed::embed(embeder, text_batch, metadata),
Embeder::Jina(embeder) => TextEmbed::embed(embeder, text_batch, metadata),
Embeder::Clip(embeder) => Embed::embed(embeder, text_batch, metadata),
Embeder::Bert(embeder) => TextEmbed::embed(embeder, text_batch, metadata),
}
}
}
Expand All @@ -76,12 +76,12 @@ pub trait Embed {
&self,
text_batch: &[String],
metadata: Option<HashMap<String, String>>,
) -> impl std::future::Future<Output = Result<Vec<EmbedData>, reqwest::Error>>;
) ->Result<Vec<EmbedData>, anyhow::Error>;

}

pub trait TextEmbed {
fn embed(&self, text_batch: &[String], metadata: Option<HashMap<String, String>>) -> impl std::future::Future<Output = Result<Vec<EmbedData>, reqwest::Error>>;
fn embed(&self, text_batch: &[String], metadata: Option<HashMap<String, String>>) -> Result<Vec<EmbedData>, anyhow::Error>;
}

pub trait EmbedImage {
Expand Down
6 changes: 3 additions & 3 deletions src/embedding_model/jina.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ impl JinaEmbeder {
Ok(Tensor::stack(&token_ids, 0)?)
}

async fn embed(&self, text_batch: &[String], metadata:Option<HashMap<String, String>>) -> Result<Vec<EmbedData>, reqwest::Error> {
fn embed(&self, text_batch: &[String], metadata:Option<HashMap<String, String>>) -> Result<Vec<EmbedData>, anyhow::Error> {
let token_ids = self.tokenize_batch(text_batch, &self.model.device).unwrap();
let embeddings = self.model.forward(&token_ids).unwrap();

Expand All @@ -97,7 +97,7 @@ impl Embed for JinaEmbeder {
&self,
text_batch: &[String],
metadata: Option<HashMap<String, String>>,
) -> impl std::future::Future<Output = Result<Vec<EmbedData>, reqwest::Error>> {
) -> Result<Vec<EmbedData>, anyhow::Error> {
self.embed(text_batch, metadata)
}
}
Expand All @@ -107,7 +107,7 @@ impl TextEmbed for JinaEmbeder {
&self,
text_batch: &[String],
metadata: Option<HashMap<String, String>>,
) -> impl std::future::Future<Output = Result<Vec<EmbedData>, reqwest::Error>> {
) -> Result<Vec<EmbedData>, anyhow::Error> {
self.embed(text_batch, metadata)
}
}
Expand Down
58 changes: 37 additions & 21 deletions src/embedding_model/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@ impl Default for OpenAIEmbeder {
}

impl Embed for OpenAIEmbeder {
fn embed(&self, text_batch: &[String], metadata: Option<HashMap<String, String>>) -> impl std::future::Future<Output = Result<Vec<EmbedData>, reqwest::Error>> {
fn embed(
&self,
text_batch: &[String],
metadata: Option<HashMap<String, String>>,
) -> Result<Vec<EmbedData>, anyhow::Error> {
self.embed(text_batch, metadata)
}
}
Expand All @@ -32,7 +36,7 @@ impl TextEmbed for OpenAIEmbeder {
&self,
text_batch: &[String],
metadata: Option<HashMap<String, String>>,
) -> impl std::future::Future<Output = Result<Vec<EmbedData>, reqwest::Error>> {
) -> Result<Vec<EmbedData>, anyhow::Error> {
self.embed(text_batch, metadata)
}
}
Expand All @@ -47,28 +51,41 @@ impl OpenAIEmbeder {
}
}

async fn embed(&self, text_batch: &[String], metadata: Option<HashMap<String, String>>) -> Result<Vec<EmbedData>, reqwest::Error> {
fn embed(
&self,
text_batch: &[String],
metadata: Option<HashMap<String, String>>,
) -> Result<Vec<EmbedData>, anyhow::Error> {
let client = Client::new();

let response = client
.post(&self.url)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&json!({
"input": text_batch,
"model": "text-embedding-3-small",
}))
.send()
.await?;

let data = response.json::<EmbedResponse>().await?;
println!("{:?}", data.usage);
let runtime = tokio::runtime::Builder::new_current_thread().enable_io()
.build()
.unwrap();

let data = runtime.block_on(async move {
let response = client
.post(&self.url)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&json!({
"input": text_batch,
"model": "text-embedding-3-small",
}))
.send()
.await
.unwrap();

let data = response.json::<EmbedResponse>().await.unwrap();
println!("{:?}", data.usage);
data
});

let emb_data = data
.data
.iter()
.zip(text_batch)
.map(move |(data, text)| EmbedData::new(data.embedding.clone(), Some(text.clone()), metadata.clone()))
.map(move |(data, text)| {
EmbedData::new(data.embedding.clone(), Some(text.clone()), metadata.clone())
})
.collect::<Vec<_>>();

Ok(emb_data)
Expand All @@ -79,15 +96,14 @@ impl OpenAIEmbeder {
mod tests {
use super::*;

#[tokio::test]
async fn test_openai_embed() {
fn test_openai_embed() {
let openai = OpenAIEmbeder::default();
let text_batch = vec![
"Once upon a time".to_string(),
"The quick brown fox jumps over the lazy dog".to_string(),
];

let embeddings = openai.embed(&text_batch, None).await.unwrap();
let embeddings = openai.embed(&text_batch, None).unwrap();
assert_eq!(embeddings.len(), 2);
}
}
6 changes: 3 additions & 3 deletions src/file_embed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ impl FileEmbeder {

}

pub async fn embed(&mut self, embeder: &Embeder, metadata: Option<HashMap< String, String>>) -> Result<(), reqwest::Error> {
self.embeddings = embeder.embed(&self.chunks, metadata).await?;
pub fn embed(&mut self, embeder: &Embeder, metadata: Option<HashMap< String, String>>) -> Result<(), anyhow::Error> {
self.embeddings = embeder.embed(&self.chunks, metadata)?;
Ok(())
}

Expand Down Expand Up @@ -91,7 +91,7 @@ mod tests {
let embeder = Embeder::Bert(BertEmbeder::default());
let mut file_embeder = FileEmbeder::new(file_path.to_string_lossy().to_string());
file_embeder.split_into_chunks(&text, 100);
file_embeder.embed(&embeder, None).await.unwrap();
file_embeder.embed(&embeder, None).unwrap();
assert_eq!(file_embeder.chunks.len(), 5);
assert_eq!(file_embeder.embeddings.len(), 5);
}
Expand Down
Loading

0 comments on commit 335a948

Please sign in to comment.