diff --git a/embedchain/loaders/sitemap.py b/embedchain/loaders/sitemap.py index 29638b0c82..ba3ec0a932 100644 --- a/embedchain/loaders/sitemap.py +++ b/embedchain/loaders/sitemap.py @@ -1,6 +1,8 @@ import concurrent.futures import hashlib import logging +import os +from urllib.parse import urlparse import requests @@ -21,23 +23,34 @@ @register_deserializable class SitemapLoader(BaseLoader): """ - This method takes a sitemap URL as input and retrieves + This method takes a sitemap URL or local file path as input and retrieves all the URLs to use the WebPageLoader to load content of each page. """ - def load_data(self, sitemap_url): + def load_data(self, sitemap_source): output = [] web_page_loader = WebPageLoader() - response = requests.get(sitemap_url) - response.raise_for_status() - soup = BeautifulSoup(response.text, "xml") + if urlparse(sitemap_source).scheme in ("http", "https"): + try: + response = requests.get(sitemap_source) + response.raise_for_status() + soup = BeautifulSoup(response.text, "xml") + except requests.RequestException as e: + logging.error(f"Error fetching sitemap from URL: {e}") + return {"doc_id": "", "data": []} + elif os.path.isfile(sitemap_source): + with open(sitemap_source, "r") as file: + soup = BeautifulSoup(file, "xml") + else: + raise ValueError("Invalid sitemap source. Please provide a valid URL or local file path.") + links = [link.text for link in soup.find_all("loc") if link.parent.name == "url"] if len(links) == 0: links = [link.text for link in soup.find_all("loc")] - doc_id = hashlib.sha256((" ".join(links) + sitemap_url).encode()).hexdigest() + doc_id = hashlib.sha256((" ".join(links) + sitemap_source).encode()).hexdigest() def load_link(link): try: