Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove Async #13

Merged
merged 2 commits into from
May 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

42 changes: 37 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,59 @@ version = "0.1.14"
edition = "2021"

[dependencies]
# Data Serialization
serde = { version = "1.0.196", features = ["derive"] }
serde_json = "1.0.112"

# HTTP Client
reqwest = { version = "0.12.2", features = ["json", "blocking"] }
serde = {version = "1.0.196", features = ["derive"]}
pdf-extract = "0.7.4"

# Filesystem
walkdir = "2.4.0"

# Regular Expressions
regex = "1.10.3"

# Parallelism
rayon = "1.8.1"

# Image Processing
image = "0.25.1"
hf-hub = "0.3.2"

# Natural Language Processing
tokenizers = "0.15.2"

# PDF Processing
pdf-extract = "0.7.4"

# Hugging Face Libraries
hf-hub = "0.3.2"
candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.5.0" }
candle-transformers = { git = "https://github.com/huggingface/candle.git", version = "0.5.0" }
candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.5.0" }

# Error Handling
anyhow = "1.0.81"
tokio = {version = "1.37.0", features=["rt-multi-thread", "macros"]}

# Asynchronous Programming
tokio = { version = "1.37.0", features = ["macros", "rt-multi-thread"] }

# Python Interoperability
pyo3 = { version = "0.21" }
intel-mkl-src = {version = "0.8.1", optional = true }

# Optional Dependency
intel-mkl-src = { version = "0.8.1", optional = true }

# Markdown Processing
markdown-parser = "0.1.2"
markdown_to_text = "1.0.0"

# Web Scraping
scraper = "0.19.0"

# Text Processing
text-cleaner = "0.1.0"
url = "2.5.0"

[dev-dependencies]
tempdir = "0.3.7"
Expand Down
2 changes: 1 addition & 1 deletion examples/bert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{path::PathBuf, time::Instant};

fn main() {
let now = Instant::now();
let out = embed_directory(PathBuf::from("test_files"), "Bert", Some(vec!["md".to_string()])).unwrap();
let out = embed_directory(PathBuf::from("test_files"), "Bert", Some(vec!["pdf".to_string()])).unwrap();
println!("{:?}", out);
let elapsed_time = now.elapsed();
println!("Elapsed Time: {}", elapsed_time.as_secs_f32());
Expand Down
13 changes: 8 additions & 5 deletions examples/web_embed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@ use candle_core::Tensor;

#[tokio::main]
async fn main() {
let url = "https://en.wikipedia.org/wiki/Long_short-term_memory";

let website_processor = website_processor::WebsiteProcesor;
let start_time = std::time::Instant::now();
let url = "https://www.scrapingbee.com/blog/web-scraping-rust/";

let website_processor = website_processor::WebsiteProcessor;
let webpage = website_processor.process_website(url).await.unwrap();
let embeder = embed_anything::embedding_model::bert::BertEmbeder::default();
let embed_data = webpage.embed_webpage(&embeder).await.unwrap();
let embed_data = webpage.embed_webpage(&embeder).unwrap();
let embeddings: Vec<Vec<f32>> = embed_data.iter().map(|data| data.embedding.clone()).collect();

let embeddings = Tensor::from_vec(
Expand All @@ -17,8 +19,8 @@ async fn main() {
&candle_core::Device::Cpu,
).unwrap();

let query = vec!["how to use lstm for nlp".to_string()];
let query_embedding: Vec<f32> = embeder.embed(&query, None).await.unwrap().iter().map(|data| data.embedding.clone()).flatten().collect();
let query = vec!["Rust for web scraping".to_string()];
let query_embedding: Vec<f32> = embeder.embed(&query, None).unwrap().iter().map(|data| data.embedding.clone()).flatten().collect();

let query_embedding_tensor = Tensor::from_vec(
query_embedding.clone(),
Expand All @@ -40,5 +42,6 @@ async fn main() {
let data = &embed_data[max_similarity_index];

println!("{:?}", data);
println!("Time taken: {:?}", start_time.elapsed());

}
6 changes: 3 additions & 3 deletions src/embedding_model/bert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ impl BertEmbeder {
Ok(Tensor::stack(&token_ids, 0)?)
}

pub async fn embed(&self, text_batch: &[String],metadata:Option<HashMap<String,String>>) -> Result<Vec<EmbedData>, reqwest::Error> {
pub fn embed(&self, text_batch: &[String],metadata:Option<HashMap<String,String>>) -> Result<Vec<EmbedData>, anyhow::Error> {
let token_ids = self.tokenize_batch(text_batch, &self.model.device).unwrap();
let token_type_ids = token_ids.zeros_like().unwrap();
let embeddings = self.model.forward(&token_ids, &token_type_ids).unwrap();
Expand All @@ -89,7 +89,7 @@ impl Embed for BertEmbeder {
fn embed(
&self,
text_batch: &[String],metadata: Option<HashMap<String,String>>
) -> impl std::future::Future<Output = Result<Vec<EmbedData>, reqwest::Error>> {
) -> Result<Vec<EmbedData>, anyhow::Error> {
self.embed(text_batch, metadata)
}
}
Expand All @@ -99,7 +99,7 @@ impl TextEmbed for BertEmbeder {
&self,
text_batch: &[String],
metadata: Option<HashMap<String,String>>
) -> impl std::future::Future<Output = Result<Vec<EmbedData>, reqwest::Error>> {
) -> Result<Vec<EmbedData>, anyhow::Error> {
self.embed(text_batch, metadata)
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/embedding_model/clip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ impl EmbedImage for ClipEmbeder {
}

impl Embed for ClipEmbeder {
async fn embed(&self, text_batch: &[String], metadata: Option<HashMap<String, String>>) -> Result<Vec<EmbedData>, reqwest::Error> {
fn embed(&self, text_batch: &[String], metadata: Option<HashMap<String, String>>) -> Result<Vec<EmbedData>, anyhow::Error> {
let (input_ids, _vec_seq) = ClipEmbeder::tokenize_sequences(
Some(text_batch.to_vec()),
&self.tokenizer,
Expand Down
14 changes: 7 additions & 7 deletions src/embedding_model/embed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,12 @@ pub enum Embeder {
}

impl Embeder {
pub async fn embed(&self, text_batch: &[String], metadata: Option<HashMap<String, String>>) -> Result<Vec<EmbedData>, reqwest::Error> {
pub fn embed(&self, text_batch: &[String], metadata: Option<HashMap<String, String>>) -> Result<Vec<EmbedData>, anyhow::Error> {
match self {
Embeder::OpenAI(embeder) => TextEmbed::embed(embeder, text_batch, metadata).await,
Embeder::Jina(embeder) => TextEmbed::embed(embeder, text_batch, metadata).await,
Embeder::Clip(embeder) => Embed::embed(embeder, text_batch, metadata).await,
Embeder::Bert(embeder) => TextEmbed::embed(embeder, text_batch, metadata).await,
Embeder::OpenAI(embeder) => TextEmbed::embed(embeder, text_batch, metadata),
Embeder::Jina(embeder) => TextEmbed::embed(embeder, text_batch, metadata),
Embeder::Clip(embeder) => Embed::embed(embeder, text_batch, metadata),
Embeder::Bert(embeder) => TextEmbed::embed(embeder, text_batch, metadata),
}
}
}
Expand All @@ -76,12 +76,12 @@ pub trait Embed {
&self,
text_batch: &[String],
metadata: Option<HashMap<String, String>>,
) -> impl std::future::Future<Output = Result<Vec<EmbedData>, reqwest::Error>>;
) ->Result<Vec<EmbedData>, anyhow::Error>;

}

pub trait TextEmbed {
fn embed(&self, text_batch: &[String], metadata: Option<HashMap<String, String>>) -> impl std::future::Future<Output = Result<Vec<EmbedData>, reqwest::Error>>;
fn embed(&self, text_batch: &[String], metadata: Option<HashMap<String, String>>) -> Result<Vec<EmbedData>, anyhow::Error>;
}

pub trait EmbedImage {
Expand Down
6 changes: 3 additions & 3 deletions src/embedding_model/jina.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ impl JinaEmbeder {
Ok(Tensor::stack(&token_ids, 0)?)
}

async fn embed(&self, text_batch: &[String], metadata:Option<HashMap<String, String>>) -> Result<Vec<EmbedData>, reqwest::Error> {
fn embed(&self, text_batch: &[String], metadata:Option<HashMap<String, String>>) -> Result<Vec<EmbedData>, anyhow::Error> {
let token_ids = self.tokenize_batch(text_batch, &self.model.device).unwrap();
let embeddings = self.model.forward(&token_ids).unwrap();

Expand All @@ -97,7 +97,7 @@ impl Embed for JinaEmbeder {
&self,
text_batch: &[String],
metadata: Option<HashMap<String, String>>,
) -> impl std::future::Future<Output = Result<Vec<EmbedData>, reqwest::Error>> {
) -> Result<Vec<EmbedData>, anyhow::Error> {
self.embed(text_batch, metadata)
}
}
Expand All @@ -107,7 +107,7 @@ impl TextEmbed for JinaEmbeder {
&self,
text_batch: &[String],
metadata: Option<HashMap<String, String>>,
) -> impl std::future::Future<Output = Result<Vec<EmbedData>, reqwest::Error>> {
) -> Result<Vec<EmbedData>, anyhow::Error> {
self.embed(text_batch, metadata)
}
}
Expand Down
58 changes: 37 additions & 21 deletions src/embedding_model/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@ impl Default for OpenAIEmbeder {
}

impl Embed for OpenAIEmbeder {
fn embed(&self, text_batch: &[String], metadata: Option<HashMap<String, String>>) -> impl std::future::Future<Output = Result<Vec<EmbedData>, reqwest::Error>> {
fn embed(
&self,
text_batch: &[String],
metadata: Option<HashMap<String, String>>,
) -> Result<Vec<EmbedData>, anyhow::Error> {
self.embed(text_batch, metadata)
}
}
Expand All @@ -32,7 +36,7 @@ impl TextEmbed for OpenAIEmbeder {
&self,
text_batch: &[String],
metadata: Option<HashMap<String, String>>,
) -> impl std::future::Future<Output = Result<Vec<EmbedData>, reqwest::Error>> {
) -> Result<Vec<EmbedData>, anyhow::Error> {
self.embed(text_batch, metadata)
}
}
Expand All @@ -47,28 +51,41 @@ impl OpenAIEmbeder {
}
}

async fn embed(&self, text_batch: &[String], metadata: Option<HashMap<String, String>>) -> Result<Vec<EmbedData>, reqwest::Error> {
fn embed(
&self,
text_batch: &[String],
metadata: Option<HashMap<String, String>>,
) -> Result<Vec<EmbedData>, anyhow::Error> {
let client = Client::new();

let response = client
.post(&self.url)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&json!({
"input": text_batch,
"model": "text-embedding-3-small",
}))
.send()
.await?;

let data = response.json::<EmbedResponse>().await?;
println!("{:?}", data.usage);
let runtime = tokio::runtime::Builder::new_current_thread().enable_io()
.build()
.unwrap();

let data = runtime.block_on(async move {
let response = client
.post(&self.url)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&json!({
"input": text_batch,
"model": "text-embedding-3-small",
}))
.send()
.await
.unwrap();

let data = response.json::<EmbedResponse>().await.unwrap();
println!("{:?}", data.usage);
data
});

let emb_data = data
.data
.iter()
.zip(text_batch)
.map(move |(data, text)| EmbedData::new(data.embedding.clone(), Some(text.clone()), metadata.clone()))
.map(move |(data, text)| {
EmbedData::new(data.embedding.clone(), Some(text.clone()), metadata.clone())
})
.collect::<Vec<_>>();

Ok(emb_data)
Expand All @@ -79,15 +96,14 @@ impl OpenAIEmbeder {
mod tests {
use super::*;

#[tokio::test]
async fn test_openai_embed() {
fn test_openai_embed() {
let openai = OpenAIEmbeder::default();
let text_batch = vec![
"Once upon a time".to_string(),
"The quick brown fox jumps over the lazy dog".to_string(),
];

let embeddings = openai.embed(&text_batch, None).await.unwrap();
let embeddings = openai.embed(&text_batch, None).unwrap();
assert_eq!(embeddings.len(), 2);
}
}
6 changes: 3 additions & 3 deletions src/file_embed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ impl FileEmbeder {

}

pub async fn embed(&mut self, embeder: &Embeder, metadata: Option<HashMap< String, String>>) -> Result<(), reqwest::Error> {
self.embeddings = embeder.embed(&self.chunks, metadata).await?;
pub fn embed(&mut self, embeder: &Embeder, metadata: Option<HashMap< String, String>>) -> Result<(), anyhow::Error> {
self.embeddings = embeder.embed(&self.chunks, metadata)?;
Ok(())
}

Expand Down Expand Up @@ -91,7 +91,7 @@ mod tests {
let embeder = Embeder::Bert(BertEmbeder::default());
let mut file_embeder = FileEmbeder::new(file_path.to_string_lossy().to_string());
file_embeder.split_into_chunks(&text, 100);
file_embeder.embed(&embeder, None).await.unwrap();
file_embeder.embed(&embeder, None).unwrap();
assert_eq!(file_embeder.chunks.len(), 5);
assert_eq!(file_embeder.embeddings.len(), 5);
}
Expand Down
Loading
Loading