Skip to content

Commit

Permalink
Demo notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
aravind10x committed Dec 31, 2024
1 parent 1d79b77 commit faec7ca
Show file tree
Hide file tree
Showing 9 changed files with 1,566 additions and 14,777 deletions.
15,995 changes: 1,316 additions & 14,679 deletions ragbuilder_sdk_demo.ipynb

Large diffs are not rendered by default.

60 changes: 60 additions & 0 deletions src/ragbuilder/config/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,66 @@ def get_class():
_PkgSpec("ollama")
]
},

# LLMs
LLMType.AZURE_OPENAI: {
"required": ["AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT"],
"optional": ["AZURE_DEPLOYMENT_NAME"],
"packages": [
_PkgSpec("langchain-openai"),
_PkgSpec("openai"),
_PkgSpec("tiktoken")
]
},
LLMType.OPENAI: {
"required": ["OPENAI_API_KEY"],
"optional": [],
"packages": [
_PkgSpec("langchain-openai"),
_PkgSpec("openai"),
_PkgSpec("tiktoken")
]
},
LLMType.HUGGINGFACE: {
"required": [],
"optional": [],
"packages": [
_PkgSpec("langchain-huggingface"),
_PkgSpec("sentence-transformers"),
_PkgSpec("torch")
]
},
LLMType.OLLAMA: {
"required": [],
"optional": [],
"packages": [
_PkgSpec("langchain-ollama"),
_PkgSpec("ollama")
]
},
LLMType.COHERE: {
"required": ["COHERE_API_KEY"],
"optional": [],
"packages": [_PkgSpec("cohere")]
},
LLMType.VERTEXAI: {
"required": ["GOOGLE_APPLICATION_CREDENTIALS"],
"optional": [],
"packages": [
_PkgSpec("langchain-google-vertexai"),
_PkgSpec("google-cloud-aiplatform")
]
},
LLMType.BEDROCK: {
"required": ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_REGION"],
"optional": [],
"packages": [_PkgSpec("boto3")]
},
LLMType.JINA: {
"required": ["JINA_API_KEY"],
"optional": [],
"packages": [_PkgSpec("jina")]
},

# Vector Databases
VectorDatabase.PINECONE: {
Expand Down
51 changes: 27 additions & 24 deletions src/ragbuilder/core/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ def __init__(
self._log_config.log_level,
self._log_config.log_file
)
self._optimized_store = None
self._optimized_retriever = None
self._optimized_generation = None
self.optimized_store = None
self.optimized_retriever = None
self.optimized_generation = None
self._optimization_results = OptimizationResults()
self._test_dataset_manager = TestDatasetManager(
self._log_config,
Expand Down Expand Up @@ -170,7 +170,7 @@ def optimize_data_ingest(
self,
config: Optional[DataIngestOptionsConfig] = None,
validate_env: bool = True
) -> Dict[str, Any]:
) -> DataIngestResults:
"""
Run data ingestion optimization
Expand All @@ -186,7 +186,8 @@ def optimize_data_ingest(
self.data_ingest_config.apply_defaults()

if validate_env:
validate_environment(self.data_ingest_config)
with console.status("[status]Validating data ingestion environment...[/status]"):
validate_environment(self.data_ingest_config)

self._ensure_eval_dataset(self.data_ingest_config)

Expand All @@ -199,9 +200,8 @@ def optimize_data_ingest(

# Store results and update telemetry
self._optimization_results.data_ingest = results
self._optimized_store = results.best_index
self.optimized_store = results.best_index
telemetry.update_optimization_results(span, results, "data_ingest")
return results

except Exception as e:
telemetry.track_error(
Expand All @@ -215,7 +215,7 @@ def optimize_data_ingest(
finally:
telemetry.flush()


return results
def optimize_retrieval(
self,
config: Optional[RetrievalOptionsConfig] = None,
Expand All @@ -233,7 +233,7 @@ def optimize_retrieval(
Returns:
RetrievalResults containing optimization results
"""
vectorstore = vectorstore or self._optimized_store
vectorstore = vectorstore or self.optimized_store
if not vectorstore:
raise DependencyError("No vectorstore found. Run data ingestion first or provide existing vectorstore.")

Expand All @@ -244,23 +244,23 @@ def optimize_retrieval(
self.retrieval_config.apply_defaults()

if validate_env:
validate_environment(self.retrieval_config)
with console.status("[status]Validating retrieval environment...[/status]"):
validate_environment(self.retrieval_config)

self._ensure_eval_dataset(self.retrieval_config)

with telemetry.optimization_span("retriever", self.retrieval_config.model_dump()) as span:
try:
results = run_retrieval_optimization(
self.retrieval_config,
vectorstore=self._optimized_store,
vectorstore=self.optimized_store,
log_config=self._log_config
)

# Store results and update telemetry
self._optimization_results.retrieval = results
self._optimized_retriever = results.best_pipeline.retriever_chain
self.optimized_retriever = results.best_pipeline.retriever_chain
telemetry.update_optimization_results(span, results, "retriever")
return results

except Exception as e:
telemetry.track_error(
Expand All @@ -274,21 +274,23 @@ def optimize_retrieval(
raise
finally:
telemetry.flush()

return results

def optimize_generation(
self,
config: Optional[GenerationOptionsConfig] = None,
retriever: Optional[Any] = None
) -> Dict[str, Any]:
) -> GenerationResults:
"""
Run Generation optimization
Returns:
Dict containing optimization results including best_config, best_score,
best_pipeline, and study_statistics
"""
self._optimized_retriever = retriever or self._optimized_retriever
if not self._optimized_retriever:
self.optimized_retriever = retriever or self.optimized_retriever
if not self.optimized_retriever:
raise DependencyError("No retriever found. Run retrieval optimization first or provide existing retriever.")

self.generation_config = config or GenerationOptionsConfig.with_defaults()
Expand All @@ -303,15 +305,14 @@ def optimize_generation(
try:
results = run_generation_optimization(
self.generation_config,
retriever=self._optimized_retriever,
retriever=self.optimized_retriever,
log_config=self._log_config
)

# Store results and update telemetry
self._optimization_results.generation = results
self._optimized_generation = results.best_pipeline
self.optimized_generation = results.best_pipeline
telemetry.update_optimization_results(span, results, "generation")
return results

except Exception as e:
telemetry.track_error(
Expand All @@ -325,6 +326,8 @@ def optimize_generation(
raise
finally:
telemetry.flush()

return results

def optimize(self) -> OptimizationResults:
"""
Expand All @@ -335,9 +338,9 @@ def optimize(self) -> OptimizationResults:
"""
with telemetry.optimization_span("ragbuilder", {"end_to_end": True}) as span:
try:
with console.status("[bold green]Validating data ingestion environment...") as status:
with console.status("[status]Validating data ingestion environment...[/status]") as status:
validate_environment(self.data_ingest_config)
status.update("[bold green]Validating retrieval environment...")
status.update("[status]Validating retrieval environment...[/status]")
if not self.retrieval_config:
self.retrieval_config = RetrievalOptionsConfig.with_defaults()
validate_environment(self.retrieval_config)
Expand All @@ -354,8 +357,6 @@ def optimize(self) -> OptimizationResults:
span.set_attribute("retrieval_score", self._optimization_results.retrieval.best_score)
if self._optimization_results.generation:
span.set_attribute("generation_score", self._optimization_results.generation.best_score)

return self._optimization_results

except Exception as e:
telemetry.track_error(
Expand All @@ -372,6 +373,8 @@ def optimize(self) -> OptimizationResults:
finally:
telemetry.flush()

return self._optimization_results

def __del__(self):
if telemetry:
try:
Expand Down Expand Up @@ -417,7 +420,7 @@ def serve(self, host: str = "0.0.0.0", port: int = 8005):
@app.post("/invoke")
async def invoke(request: QueryRequest) -> Dict[str, Any]:
try:
result = self._optimized_generation.query(
result = self.optimized_generation.query(
request.get_query()
)
console.print(f"Question:{request.get_query()}")
Expand Down
9 changes: 6 additions & 3 deletions src/ragbuilder/core/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,11 @@ def flush(self):
logger.debug(f"Error flushing telemetry: {e}")

def shutdown(self):
    """Flush pending telemetry and shut down the meter provider.

    Best-effort: telemetry must never crash the host application, so any
    error during flush/shutdown is logged at debug level and swallowed.
    No-op when telemetry is disabled or no meter provider was created.
    """
    try:
        if self.enabled and self.meter_provider:
            # Flush buffered metrics before tearing the provider down so
            # the final measurements are not lost.
            self.meter_provider.force_flush()
            self.meter_provider.shutdown()
    except Exception as e:
        # Deliberate broad catch: shutdown runs during teardown (possibly
        # from __del__), where raising would mask the real exit path.
        logger.debug(f"Error shutting down telemetry: {e}")

telemetry = RAGBuilderTelemetry()
67 changes: 66 additions & 1 deletion src/ragbuilder/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ragbuilder.config.retriever import RetrievalOptionsConfig
import json
import importlib
import requests

logger = logging.getLogger(__name__)
os.environ['USER_AGENT'] = "ragbuilder"
Expand Down Expand Up @@ -76,6 +77,49 @@ def validate_environment(config: Union[DataIngestOptionsConfig, RetrievalOptions
missing_env = []
missing_packages = []

# Validate input_source for DataIngestOptionsConfig
if isinstance(config, DataIngestOptionsConfig):
input_sources = [config.input_source] if isinstance(config.input_source, str) else config.input_source
for source in input_sources:
if not _is_valid_input_source(source):
raise ValueError(f"Invalid input source: {source}")

# Validate test_dataset if provided in evaluation_config
if hasattr(config, 'evaluation_config') and config.evaluation_config:
if config.evaluation_config.test_dataset:
if not os.path.isfile(config.evaluation_config.test_dataset):
raise ValueError(f"Invalid test dataset path: {config.evaluation_config.test_dataset}")

# Validate LLM and embeddings in evaluation_config
if config.evaluation_config.llm and hasattr(config.evaluation_config.llm, 'type'):
_missing_env, _missing_packages = validate_component_env(config.evaluation_config.llm.type)
missing_env.extend(_missing_env)
missing_packages.extend(_missing_packages)

if config.evaluation_config.embeddings and hasattr(config.evaluation_config.embeddings, 'type'):
_missing_env, _missing_packages = validate_component_env(config.evaluation_config.embeddings.type)
missing_env.extend(_missing_env)
missing_packages.extend(_missing_packages)

# Validate eval data generation config if present
if config.evaluation_config.eval_data_generation_config:
gen_config = config.evaluation_config.eval_data_generation_config

if gen_config.generator_model and hasattr(gen_config.generator_model, 'type'):
_missing_env, _missing_packages = validate_component_env(gen_config.generator_model.type)
missing_env.extend(_missing_env)
missing_packages.extend(_missing_packages)

if gen_config.critic_model and hasattr(gen_config.critic_model, 'type'):
_missing_env, _missing_packages = validate_component_env(gen_config.critic_model.type)
missing_env.extend(_missing_env)
missing_packages.extend(_missing_packages)

if gen_config.embedding_model and hasattr(gen_config.embedding_model, 'type'):
_missing_env, _missing_packages = validate_component_env(gen_config.embedding_model.type)
missing_env.extend(_missing_env)
missing_packages.extend(_missing_packages)

if hasattr(config, 'document_loaders'):
for loader in config.document_loaders:
_missing_env, _missing_packages = validate_component_env(loader.type)
Expand Down Expand Up @@ -164,4 +208,25 @@ def serialize_config(config: Any) -> str:
)
except Exception as e:
logger.error(f"Failed to serialize config: {str(e)}")
return str(config) # Fallback to string representation
return str(config) # Fallback to string representation

def _is_valid_input_source(input_path: str) -> bool:
"""
Validate if input source is a valid file, directory, or URL.
Args:
input_path: Path to validate
Returns:
bool: True if valid, False otherwise
"""
# Check if it's a URL
if re.match(r'https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+', input_path):
try:
response = requests.head(input_path)
return response.status_code == 200
except:
return False

# Check if it's a file or directory
return os.path.isfile(input_path) or os.path.isdir(input_path)
Loading

0 comments on commit faec7ca

Please sign in to comment.