Commit
Add url dependency and fix typo in WebsiteProcessor
akshayballal95 committed Apr 22, 2024
1 parent 8f97ff8 commit b65b979
Showing 6 changed files with 67 additions and 86 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -57,6 +57,7 @@ scraper = "0.19.0"
 
 # Text Processing
 text-cleaner = "0.1.0"
+url = "2.5.0"
 
 [dev-dependencies]
 tempdir = "0.3.7"
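
The new url dependency is what the rewritten link extractor in website_processor.rs builds on. A minimal sketch (not part of the commit) of the two calls it relies on, Url::parse and Url::join, plus fragment stripping:

// Sketch only: resolving relative hrefs against a base URL, as a browser would.
use url::Url;

fn main() -> Result<(), url::ParseError> {
    let base = Url::parse("https://www.scrapingbee.com/blog/web-scraping-rust/")?;

    let mut link = base.join("/blog/python-web-scraping/#header")?;
    assert_eq!(link.as_str(), "https://www.scrapingbee.com/blog/python-web-scraping/#header");

    // Fragments can be stripped to normalize links before deduplication.
    link.set_fragment(None);
    assert_eq!(link.as_str(), "https://www.scrapingbee.com/blog/python-web-scraping/");
    Ok(())
}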
9 changes: 6 additions & 3 deletions examples/web_embed.rs
@@ -3,9 +3,11 @@ use candle_core::Tensor
 
 #[tokio::main]
 async fn main() {
-    let url = "https://en.wikipedia.org/wiki/Long_short-term_memory";
-
-    let website_processor = website_processor::WebsiteProcesor;
+    let start_time = std::time::Instant::now();
+    let url = "https://www.scrapingbee.com/blog/web-scraping-rust/";
+
+    let website_processor = website_processor::WebsiteProcessor;
     let webpage = website_processor.process_website(url).await.unwrap();
     let embeder = embed_anything::embedding_model::bert::BertEmbeder::default();
     let embed_data = webpage.embed_webpage(&embeder).unwrap();
@@ -17,7 +19,7 @@ async fn main() {
         &candle_core::Device::Cpu,
     ).unwrap();
 
-    let query = vec!["how to use lstm for nlp".to_string()];
+    let query = vec!["Rust for web scraping".to_string()];
     let query_embedding: Vec<f32> = embeder.embed(&query, None).unwrap().iter().map(|data| data.embedding.clone()).flatten().collect();
 
     let query_embedding_tensor = Tensor::from_vec(
@@ -40,5 +42,6 @@ async fn main() {
     let data = &embed_data[max_similarity_index];
 
     println!("{:?}", data);
+    println!("Time taken: {:?}", start_time.elapsed());
 
 }
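
The collapsed middle of this example builds a tensor from the webpage embeddings and picks max_similarity_index, the chunk closest to the query, by cosine similarity. A dependency-free sketch of that selection step (illustrative only; the real example does this with candle tensors, and the embedding values here are made up):

// Cosine similarity between two embedding vectors.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    dot / (norm_a * norm_b)
}

fn main() {
    let query = [0.1_f32, 0.7, 0.2];
    let chunks = [vec![0.1_f32, 0.6, 0.3], vec![0.9, 0.0, 0.1]];

    // Index of the chunk embedding most similar to the query embedding.
    let max_similarity_index = chunks
        .iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| {
            cosine_similarity(&query, a)
                .partial_cmp(&cosine_similarity(&query, b))
                .unwrap()
        })
        .map(|(i, _)| i)
        .unwrap();

    println!("most similar chunk: {}", max_similarity_index);
}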
140 changes: 58 additions & 82 deletions src/file_processor/website_processor.rs
@@ -1,10 +1,10 @@
-use std::
-    collections::{HashMap, HashSet};
+use std::collections::{HashMap, HashSet};
 
-use anyhow::{Error, Ok};
-use regex::Regex;
-use scraper::Selector;
+use anyhow::Result;
+use scraper::{Html, Selector};
 use serde_json::json;
 use text_cleaner::clean::Clean;
+use url::Url;
 
 use crate::{
     embedding_model::embed::{EmbedData, TextEmbed},
@@ -22,45 +22,36 @@ pub struct WebPage {
 }
 
 impl WebPage {
-    pub fn embed_webpage<T: TextEmbed>(&self, embeder: &T) -> Result<Vec<EmbedData>, Error>{
+    pub fn embed_webpage<T: TextEmbed>(&self, embeder: &T) -> Result<Vec<EmbedData>> {
         let mut embed_data = Vec::new();
-        let paragraph_embeddings = if let Some(paragraphs) = &self.paragraphs {
-            self.embed_tag::<T>("p", paragraphs.to_vec(), &embeder).unwrap_or(Vec::new())
-        } else {
-            Vec::new()
-        };
 
-        let header_embeddings = if let Some(headers) = &self.headers {
-            self.embed_tag::<T>("h1", headers.to_vec(), &embeder).unwrap_or(Vec::new())
-        } else {
-            Vec::new()
-        };
+        if let Some(paragraphs) = &self.paragraphs {
+            embed_data.extend(self.embed_tag("p", paragraphs, embeder)?);
+        }
 
-        let code_embeddings = if let Some(codes) = &self.codes {
-            self.embed_tag::<T>("code", codes.to_vec(), &embeder).unwrap_or(Vec::new())
-        } else {
-            Vec::new()
-        };
+        if let Some(headers) = &self.headers {
+            embed_data.extend(self.embed_tag("h1", headers, embeder)?);
+        }
+
+        if let Some(codes) = &self.codes {
+            embed_data.extend(self.embed_tag("code", codes, embeder)?);
+        }
 
-        embed_data.extend(paragraph_embeddings);
-        embed_data.extend(header_embeddings);
-        embed_data.extend(code_embeddings);
         Ok(embed_data)
     }
 
-    pub fn embed_tag<T: TextEmbed>(&self,tag: &str, tag_content: Vec<String>, embeder: &T) -> Result<Vec<EmbedData>, Error> {
+    pub fn embed_tag<T: TextEmbed>(&self, tag: &str, tag_content: &[String], embeder: &T) -> Result<Vec<EmbedData>> {
         let mut embed_data = Vec::new();
 
         for content in tag_content {
             let mut file_embeder = FileEmbeder::new(self.url.to_string());
 
-            let chunks = match file_embeder.split_into_chunks(&content, 1000) {
+            let chunks = match file_embeder.split_into_chunks(content, 1000) {
                 Some(chunks) => chunks,
                 None => continue,
             };
 
-            match chunks.len() {
-                0 => continue,
-                _ => (),
+            if chunks.is_empty() {
+                continue;
             }
 
             let tag_type = match tag {
@@ -78,44 +69,34 @@ impl WebPage {
                 "full_text": content,
             });
 
-            let metadata_hashmap: HashMap<String, String> =
-                serde_json::from_value(metadata).unwrap();
-
-
-            let embeddings = embeder
-                .embed(&chunks, Some(metadata_hashmap))
-
-                .unwrap_or(Vec::new());
-            for embedding in embeddings {
-                embed_data.push(embedding);
-
-            }
+            let metadata_hashmap: HashMap<String, String> = serde_json::from_value(metadata)?;
+
+            let embeddings = embeder.embed(&chunks, Some(metadata_hashmap))?;
+            embed_data.extend(embeddings);
         }
 
         Ok(embed_data)
     }
 }
 
 /// A struct for processing websites.
-pub struct WebsiteProcesor;
+pub struct WebsiteProcessor;
 
-
-impl WebsiteProcesor {
+impl WebsiteProcessor {
     pub fn new() -> Self {
         Self {}
     }
 
-    pub async fn process_website(&self, website: &str) -> Result<WebPage, Error> {
+    pub async fn process_website(&self, website: &str) -> Result<WebPage> {
         let response = reqwest::get(website).await?.text().await?;
-        let document = scraper::Html::parse_document(&response);
+        let document = Html::parse_document(&response);
         let headers = self.get_text_from_tag("h1,h2,h3", &document)?;
         let paragraphs = self.get_text_from_tag("p", &document)?;
         let codes = self.get_text_from_tag("code", &document)?;
         let links = self.extract_links(website, &document)?;
-        let binding = self.get_text_from_tag("h1", &document)?;
-        let title = binding.first();
+        let title = self.get_title(&document)?;
         let web_page = WebPage {
             url: website.to_string(),
-            title: title.map(|s| s.to_string()),
+            title,
             headers: Some(headers),
             paragraphs: Some(paragraphs),
             codes: Some(codes),
@@ -125,42 +106,37 @@ impl WebsiteProcesor {
         Ok(web_page)
     }
 
-    pub fn get_text_from_tag(
-        &self,
-        tag: &str,
-        document: &scraper::Html,
-    ) -> Result<Vec<String>, Error> {
-        let selector = Selector::parse(tag).map_err(|e| Error::msg(e.to_string()))?;
+    fn get_text_from_tag(&self, tag: &str, document: &Html) -> Result<Vec<String>, anyhow::Error> {
+        let selector = Selector::parse(tag).unwrap();
         Ok(document
             .select(&selector)
-            .map(|element| element.text().collect::<String>().trim())
-            .collect::<Vec<_>>())
+            .map(|element| element.text().collect::<String>().trim().to_string())
+            .collect())
     }
 
-    pub fn extract_links(
-        &self,
-        website: &str,
-        document: &scraper::Html,
-    ) -> Result<HashSet<String>, Error> {
+    fn extract_links(&self, website: &str, document: &Html) -> Result<HashSet<String>> {
         let mut links = HashSet::new();
-        let _ = document
-            .select(&Selector::parse("a").unwrap())
-            .map(|element| {
-                let link = element.value().attr("href").unwrap_or_default().to_string();
-                let regex: Regex = Regex::new(
-                    r"^((https?|ftp|smtp):\/\/)?(www.)?[a-z0-9]+\.[a-z]+(\/[a-zA-Z0-9#]+\/?)*$",
-                )
-                .unwrap();
-                // Check if the link is a valid URL using regex. If not append the website URL to the beginning of the link.
-                if !regex.is_match(&link) {
-                    links.insert(format!("{}{}", website, link));
-                } else {
-                    links.insert(link);
-                }
-            });
+        let base_url = Url::parse(website)?;
+
+        for element in document.select(&Selector::parse("a").unwrap()) {
+            if let Some(href) = element.value().attr("href") {
+                let mut link_url = base_url.join(href)?;
+                // Normalize URLs, remove fragments and ensure they are absolute.
+                link_url.set_fragment(None);
+                links.insert(link_url.to_string());
+            }
+        }
 
         Ok(links)
     }
+
+    fn get_title(&self, document: &Html) -> Result<Option<String>> {
+        if let Some(title_element) = document.select(&Selector::parse("title").unwrap()).next() {
+            Ok(Some(title_element.text().collect::<String>()))
+        } else {
+            Ok(None)
+        }
+    }
 }
 
 #[cfg(test)]
@@ -169,9 +145,9 @@ mod tests {
 
     #[tokio::test]
     async fn test_process_website() {
-        let website_processor = WebsiteProcesor;
+        let website_processor = WebsiteProcessor::new();
         let website = "https://www.scrapingbee.com/blog/web-scraping-rust/";
-        let result = website_processor.process_website(website);
-        assert!(result.await.is_ok());
+        let result = website_processor.process_website(website).await;
+        assert!(result.is_ok());
     }
 }
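
Taken together, the rewritten extract_links swaps the hand-rolled regex for RFC 3986 resolution via the url crate. A self-contained sketch of the resulting normalization behavior (assuming scraper 0.19 and url 2.5 as above; the HTML snippet and example.com base are made up):

use std::collections::HashSet;

use scraper::{Html, Selector};
use url::Url;

fn main() {
    // Made-up HTML standing in for a fetched page.
    let html = r##"<a href="/pricing">Pricing</a> <a href="#top">Top</a>"##;
    let document = Html::parse_document(html);
    let base_url = Url::parse("https://example.com/blog/post/").unwrap();
    let mut links = HashSet::new();

    // Same loop shape as the rewritten extract_links.
    for element in document.select(&Selector::parse("a").unwrap()) {
        if let Some(href) = element.value().attr("href") {
            let mut link_url = base_url.join(href).unwrap();
            link_url.set_fragment(None); // drop "#top"-style fragments
            links.insert(link_url.to_string());
        }
    }

    // Relative hrefs become absolute; the fragment-only link collapses to the page itself.
    assert!(links.contains("https://example.com/pricing"));
    assert!(links.contains("https://example.com/blog/post/"));
}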
2 changes: 1 addition & 1 deletion src/lib.rs
@@ -169,7 +169,7 @@ pub fn embed_directory(
 
 #[pyfunction]
 pub fn emb_webpage(url: String, embeder: &str) -> PyResult<Vec<EmbedData>> {
-    let website_processor = file_processor::website_processor::WebsiteProcesor::new();
+    let website_processor = file_processor::website_processor::WebsiteProcessor::new();
     let runtime = Builder::new_multi_thread().enable_all().build().unwrap();
     let webpage = runtime
         .block_on(website_processor.process_website(url.as_ref()))
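
emb_webpage bridges a synchronous Python entry point to the async processor by blocking on a tokio runtime. A minimal sketch of that pattern (the process function here is a hypothetical stand-in for process_website, not the crate's API):

use tokio::runtime::Builder;

// Hypothetical stand-in for the async process_website call.
async fn process(url: &str) -> String {
    format!("processed {url}")
}

fn main() {
    // Same runtime construction as emb_webpage: multi-threaded, all drivers enabled.
    let runtime = Builder::new_multi_thread().enable_all().build().unwrap();
    let result = runtime.block_on(process("https://example.com"));
    println!("{result}");
}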
Binary file added test_files/attention.pdf
Binary file not shown.
