diff --git a/embedchain/loaders/github.py b/embedchain/loaders/github.py index a1d990d609..8800b6c3cd 100644 --- a/embedchain/loaders/github.py +++ b/embedchain/loaders/github.py @@ -3,6 +3,8 @@ import logging import os +from tqdm import tqdm + from embedchain.loaders.base_loader import BaseLoader from embedchain.loaders.json import JSONLoader from embedchain.loaders.mdx import MdxLoader @@ -53,14 +55,24 @@ def _load_file(file_path: str): return data.get("data", []) + def _is_file_empty(file_path): + return os.path.getsize(file_path) == 0 + + def _is_whitelisted(file_path): + whitelisted_extensions = ["md", "txt", "html", "json", "py", "js", "jsx", "ts", "tsx", "mdx", "rst"] + _, file_extension = os.path.splitext(file_path) + return file_extension[1:] in whitelisted_extensions + def _add_repo_files(repo_path: str): with concurrent.futures.ThreadPoolExecutor() as executor: future_to_file = { executor.submit(_load_file, os.path.join(root, filename)): os.path.join(root, filename) for root, _, files in os.walk(repo_path) for filename in files - } # noqa: E501 - for future in concurrent.futures.as_completed(future_to_file): + if _is_whitelisted(os.path.join(root, filename)) + and not _is_file_empty(os.path.join(root, filename)) # noqa:E501 + } + for future in tqdm(concurrent.futures.as_completed(future_to_file), total=len(future_to_file)): file = future_to_file[future] try: results = future.result() diff --git a/embedchain/utils.py b/embedchain/utils.py index 2ae20e3991..21a55e79bd 100644 --- a/embedchain/utils.py +++ b/embedchain/utils.py @@ -216,6 +216,10 @@ def is_openapi_yaml(yaml_content): logging.debug(f"Source of `{formatted_source}` detected as `csv`.") return DataType.CSV + if url.path.endswith(".mdx") or url.path.endswith(".md"): + logging.debug(f"Source of `{formatted_source}` detected as `mdx`.") + return DataType.MDX + if url.path.endswith(".docx"): logging.debug(f"Source of `{formatted_source}` detected as `docx`.") return DataType.DOCX @@ -292,6 +296,10 @@ def is_openapi_yaml(yaml_content): logging.debug(f"Source of `{formatted_source}` detected as `xml`.") return DataType.XML + if source.endswith(".mdx") or source.endswith(".md"): + logging.debug(f"Source of `{formatted_source}` detected as `mdx`.") + return DataType.MDX + if source.endswith(".yaml"): with open(source, "r") as file: yaml_content = yaml.safe_load(file)