Skip to content

Commit

Permalink
Add nltk dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
aravind10x committed Jan 7, 2025
1 parent faec7ca commit 9f46e0e
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/ragbuilder/config/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def get_class():
ParserType.UNSTRUCTURED: {
"required": [],
"optional": [],
"packages": [_PkgSpec("unstructured")]
"packages": [_PkgSpec("unstructured"), _PkgSpec("nltk")]
},
ParserType.PYMUPDF: {
"required": [],
Expand Down
24 changes: 23 additions & 1 deletion src/ragbuilder/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
UnstructuredFileLoader
)
from langchain_core.documents import Document
from ragbuilder.config.components import COMPONENT_ENV_REQUIREMENTS
from ragbuilder.config.components import COMPONENT_ENV_REQUIREMENTS, ParserType
from ragbuilder.config.data_ingest import DataIngestOptionsConfig
from ragbuilder.config.retriever import RetrievalOptionsConfig
import json
Expand Down Expand Up @@ -177,6 +177,28 @@ def validate_component_env(component_value: str) -> Tuple[List[str], List[str]]:
missing_env.extend([var for var in requirements["required"] if not os.getenv(var)])
missing_packages.extend([pkg_name for pkg in requirements.get("packages", [])
if (pkg_name := pkg.validate())])

if component_value == ParserType.UNSTRUCTURED:
    # NLTK data needed by the unstructured parser, mapped to the nltk_data
    # category each resource is stored under. The category matters: the
    # tagger lives under 'taggers/', so looking it up under 'tokenizers/'
    # never succeeds and would trigger a download attempt on every call.
    # Defined before the try so the catch-all handler below can always use it.
    nltk_resources = {
        'punkt': 'tokenizers',
        'punkt_tab': 'tokenizers',
        'averaged_perceptron_tagger': 'taggers',
    }
    try:
        import nltk

        for resource, category in nltk_resources.items():
            try:
                nltk.data.find(f'{category}/{resource}')
            except LookupError:
                # Resource is not installed locally: try to fetch it.
                logger.info(f"Downloading required NLTK data '{resource}' for unstructured parser...")
                downloaded = False
                try:
                    # nltk.download() reports failure through its boolean
                    # return value rather than by raising, so check it.
                    downloaded = bool(nltk.download(resource, quiet=True))
                    if not downloaded:
                        logger.warning(f"Failed to download NLTK data '{resource}'")
                except Exception as e:
                    logger.warning(f"Failed to download NLTK data '{resource}': {str(e)}")
                if not downloaded:
                    missing_packages.append(f"nltk[{resource}]")

    except ImportError:
        # The generic package check above may already have reported nltk;
        # avoid listing it twice.
        if "nltk" not in missing_packages:
            missing_packages.append("nltk")
    except Exception as e:
        # Anything unexpected (e.g. a broken nltk install or unreadable data
        # directory): report every resource not already reported as missing.
        logger.warning(f"Failed to validate/download NLTK data: {str(e)}")
        for resource in nltk_resources:
            entry = f"nltk[{resource}]"
            if entry not in missing_packages:
                missing_packages.append(entry)

return missing_env, missing_packages

def simplify_model_config(obj: Any) -> Dict[str, Any]:
Expand Down

0 comments on commit 9f46e0e

Please sign in to comment.