From e0b73e6a5ad3d88b12a5332098d5ee8902603eff Mon Sep 17 00:00:00 2001 From: Deshraj Yadav Date: Thu, 16 Nov 2023 16:01:43 -0800 Subject: [PATCH] [Loaders] Improve web page and sitemap loader usability (#961) --- embedchain/loaders/sitemap.py | 3 ++- embedchain/loaders/web_page.py | 16 +++++++++++----- tests/loaders/test_web_page.py | 2 +- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/embedchain/loaders/sitemap.py b/embedchain/loaders/sitemap.py index fa8bbe50ab..707c891dc4 100644 --- a/embedchain/loaders/sitemap.py +++ b/embedchain/loaders/sitemap.py @@ -3,6 +3,7 @@ import logging import requests +from tqdm import tqdm try: from bs4 import BeautifulSoup @@ -52,7 +53,7 @@ def load_link(link): with concurrent.futures.ThreadPoolExecutor() as executor: future_to_link = {executor.submit(load_link, link): link for link in links} - for future in concurrent.futures.as_completed(future_to_link): + for future in tqdm(concurrent.futures.as_completed(future_to_link), total=len(links)): link = future_to_link[future] try: data = future.result() diff --git a/embedchain/loaders/web_page.py b/embedchain/loaders/web_page.py index 98109d0716..931031826f 100644 --- a/embedchain/loaders/web_page.py +++ b/embedchain/loaders/web_page.py @@ -17,15 +17,17 @@ @register_deserializable class WebPageLoader(BaseLoader): + # Shared session for all instances + _session = requests.Session() + def load_data(self, url): - """Load data from a web page.""" - response = requests.get(url) + """Load data from a web page using a shared requests session.""" + response = self._session.get(url, timeout=30) + response.raise_for_status() data = response.content content = self._get_clean_content(data, url) - meta_data = { - "url": url, - } + meta_data = {"url": url} doc_id = hashlib.sha256((content + url).encode()).hexdigest() return { @@ -86,3 +88,7 @@ def _get_clean_content(self, html, url) -> str: ) return content + + @classmethod + def close_session(cls): + cls._session.close() diff --git a/tests/loaders/test_web_page.py b/tests/loaders/test_web_page.py index cdaf09447c..3134d4d0c1 100644 --- a/tests/loaders/test_web_page.py +++ b/tests/loaders/test_web_page.py @@ -27,7 +27,7 @@ def test_load_data(web_page_loader): """ - with patch("embedchain.loaders.web_page.requests.get", return_value=mock_response): + with patch("embedchain.loaders.web_page.WebPageLoader._session.get", return_value=mock_response): result = web_page_loader.load_data(page_url) content = web_page_loader._get_clean_content(mock_response.content, page_url)