diff --git a/src/ragas/testset/graph.py b/src/ragas/testset/graph.py index 94f10aeb2..fbc73c765 100644 --- a/src/ragas/testset/graph.py +++ b/src/ragas/testset/graph.py @@ -1,6 +1,7 @@ import json import typing as t import uuid +from copy import deepcopy from dataclasses import dataclass, field from enum import Enum from pathlib import Path @@ -268,6 +269,53 @@ def dfs(node: Node, cluster: t.Set[Node], depth: int, path: t.Tuple[Node, ...]): return unique_clusters + def remove_node( + self, node: Node, inplace: bool = True + ) -> t.Optional["KnowledgeGraph"]: + """ + Removes a node and its associated relationships from the knowledge graph. + + Parameters + ---------- + node : Node + The node to be removed from the knowledge graph. + inplace : bool, optional + If True, modifies the knowledge graph in place. + If False, returns a modified copy with the node removed. + + Returns + ------- + KnowledgeGraph or None + Returns a modified copy of the knowledge graph if `inplace` is False. + Returns None if `inplace` is True. + + Raises + ------ + ValueError + If the node is not present in the knowledge graph. + """ + if node not in self.nodes: + raise ValueError("Node is not present in the knowledge graph.") + + if inplace: + # Modify the current instance + self.nodes.remove(node) + self.relationships = [ + rel + for rel in self.relationships + if rel.source != node and rel.target != node + ] + else: + # Create a deep copy and modify it + new_graph = deepcopy(self) + new_graph.nodes.remove(node) + new_graph.relationships = [ + rel + for rel in new_graph.relationships + if rel.source != node and rel.target != node + ] + return new_graph + def find_direct_clusters( self, relationship_condition: t.Callable[[Relationship], bool] = lambda _: True ) -> t.Dict[Node, t.List[t.Set[Node]]]: diff --git a/src/ragas/testset/graph_queries.py b/src/ragas/testset/graph_queries.py index 23397d803..5e6239d73 100644 --- a/src/ragas/testset/graph_queries.py +++ b/src/ragas/testset/graph_queries.py @@ -36,3 +36,38 @@ def dfs(current_node: Node, current_level: int): dfs(node, 1) return children + + +def get_parent_nodes(node: Node, graph: KnowledgeGraph, level: int = 1) -> t.List[Node]: + """ + Get the parent nodes of a given node up to a specified level. + + Parameters + ---------- + node : Node + The node to get the parents of. + graph : KnowledgeGraph + The knowledge graph containing the node. + level : int + The maximum level to which parent nodes are searched. + + Returns + ------- + List[Node] + The list of parent nodes up to the specified level. 
+ """ + parents = [] + + # Helper function to perform depth-limited search for parent nodes + def dfs(current_node: Node, current_level: int): + if current_level > level: + return + for rel in graph.relationships: + if rel.target == current_node and rel.type == "child": + parents.append(rel.source) + dfs(rel.source, current_level + 1) + + # Start DFS from the initial node at level 0 + dfs(node, 1) + + return parents diff --git a/src/ragas/testset/transforms/__init__.py b/src/ragas/testset/transforms/__init__.py index 1e91ff5dc..ccfe73b2c 100644 --- a/src/ragas/testset/transforms/__init__.py +++ b/src/ragas/testset/transforms/__init__.py @@ -1,4 +1,4 @@ -from .base import BaseGraphTransformation, Extractor, RelationshipBuilder, Splitter +from .base import BaseGraphTransformation, Extractor, RelationshipBuilder, Splitter, NodeFilter from .default import default_transforms from .engine import Parallel, Transforms, apply_transforms, rollback_transforms from .extractors import ( @@ -13,6 +13,7 @@ SummaryCosineSimilarityBuilder, ) from .splitters import HeadlineSplitter +from .filters import CustomNodeFilter __all__ = [ # base @@ -37,4 +38,6 @@ "SummaryCosineSimilarityBuilder", # splitters "HeadlineSplitter", + "CustomNodeFilter", + "NodeFilter", ] diff --git a/src/ragas/testset/transforms/base.py b/src/ragas/testset/transforms/base.py index 1e95fd8e1..3c1892c81 100644 --- a/src/ragas/testset/transforms/base.py +++ b/src/ragas/testset/transforms/base.py @@ -322,3 +322,55 @@ async def apply_build_relationships( filtered_kg = self.filter(kg) return [apply_build_relationships(filtered_kg=filtered_kg, original_kg=kg)] + + +@dataclass +class NodeFilter(BaseGraphTransformation): + + async def transform(self, kg: KnowledgeGraph) -> KnowledgeGraph: + + filtered = self.filter(kg) + + for node in filtered.nodes: + flag = await self.custom_filter(node, kg) + if flag: + kg_ = kg.remove_node(node, inplace=False) + if isinstance(kg_, KnowledgeGraph): + return kg_ + else: + raise ValueError("Error in removing node") + return kg + + @abstractmethod + async def custom_filter(self, node: Node, kg: KnowledgeGraph) -> bool: + """ + Abstract method to filter a node based on a prompt. + + Parameters + ---------- + node : Node + The node to be filtered. + + Returns + ------- + bool + A boolean indicating whether the node should be filtered. 
+ """ + pass + + def generate_execution_plan(self, kg: KnowledgeGraph) -> t.List[t.Coroutine]: + """ + Generates a list of coroutines to be executed + """ + + async def apply_filter(node: Node): + if await self.custom_filter(node, kg): + kg.remove_node(node) + + filtered = self.filter(kg) + return [apply_filter(node) for node in filtered.nodes] + + +@dataclass +class LLMBasedNodeFilter(NodeFilter, PromptMixin): + llm: BaseRagasLLM = field(default_factory=llm_factory) diff --git a/src/ragas/testset/transforms/default.py b/src/ragas/testset/transforms/default.py index db58045b5..071c42756 100644 --- a/src/ragas/testset/transforms/default.py +++ b/src/ragas/testset/transforms/default.py @@ -9,6 +9,7 @@ SummaryExtractor, ) from ragas.testset.transforms.extractors.llm_based import NERExtractor, ThemesExtractor +from ragas.testset.transforms.filters import CustomNodeFilter from ragas.testset.transforms.relationship_builders import ( CosineSimilarityBuilder, OverlapScoreBuilder, @@ -82,11 +83,14 @@ def summary_filter(node): threshold=0.01, filter_nodes=lambda node: node.type == NodeType.CHUNK ) + node_filter = CustomNodeFilter(llm=llm, filter_nodes=lambda node: node.type == NodeType.CHUNK) + transforms = [ headline_extractor, splitter, - Parallel(summary_extractor, theme_extractor, ner_extractor), - summary_emb_extractor, + summary_extractor, + node_filter, + Parallel(summary_emb_extractor, theme_extractor, ner_extractor), Parallel(cosine_sim_builder, ner_overlap_sim), ] diff --git a/src/ragas/testset/transforms/extractors/llm_based.py b/src/ragas/testset/transforms/extractors/llm_based.py index 78b7c5f98..83e29c3f6 100644 --- a/src/ragas/testset/transforms/extractors/llm_based.py +++ b/src/ragas/testset/transforms/extractors/llm_based.py @@ -70,7 +70,7 @@ class Headlines(BaseModel): class HeadlinesExtractorPrompt(PydanticPrompt[StringIO, Headlines]): - instruction: str = "Extract only level 2 headings from the given text." + instruction: str = "Extract only level 2 and level 3 headings from the given text." input_model: t.Type[StringIO] = StringIO output_model: t.Type[Headlines] = Headlines @@ -78,27 +78,36 @@ class HeadlinesExtractorPrompt(PydanticPrompt[StringIO, Headlines]): ( StringIO( text="""\ - Introduction - Overview of the topic... + Introduction + Overview of the topic... - Main Concepts - Explanation of core ideas... + Main Concepts + Explanation of core ideas... - Detailed Analysis - Techniques and methods for analysis... + Detailed Analysis + Techniques and methods for analysis... - Subsection: Specialized Techniques - Further details on specialized techniques... + Subsection: Specialized Techniques + Further details on specialized techniques... - Future Directions - Insights into upcoming trends... + Future Directions + Insights into upcoming trends... - Conclusion - Final remarks and summary. - """, + Subsection: Next Steps in Research + Discussion of new areas of study... + + Conclusion + Final remarks and summary. + """ ), Headlines( - headlines=["Main Concepts", "Detailed Analysis", "Future Directions"] + headlines=[ + "Main Concepts", + "Detailed Analysis", + "Subsection: Specialized Techniques", + "Future Directions", + "Subsection: Next Steps in Research", + ] ), ), ] @@ -108,15 +117,24 @@ class NEROutput(BaseModel): entities: t.List[str] -class NERPrompt(PydanticPrompt[StringIO, NEROutput]): - instruction: str = "Extract named entities from the given text." 
- input_model: t.Type[StringIO] = StringIO +class TextWithExtractionLimit(BaseModel): + text: str + max_num: int = 10 + + +class NERPrompt(PydanticPrompt[TextWithExtractionLimit, NEROutput]): + instruction: str = ( + "Extract the named entities from the given text, limiting the output to the top entities. " + "Ensure the number of entities does not exceed the specified maximum." + ) + input_model: t.Type[TextWithExtractionLimit] = TextWithExtractionLimit output_model: t.Type[NEROutput] = NEROutput - examples: t.List[t.Tuple[StringIO, NEROutput]] = [ + examples: t.List[t.Tuple[TextWithExtractionLimit, NEROutput]] = [ ( - StringIO( + TextWithExtractionLimit( text="""Elon Musk, the CEO of Tesla and SpaceX, announced plans to expand operations to new locations in Europe and Asia. - This expansion is expected to create thousands of jobs, particularly in cities like Berlin and Shanghai.""" + This expansion is expected to create thousands of jobs, particularly in cities like Berlin and Shanghai.""", + max_num=10, ), NEROutput( entities=[ @@ -246,12 +264,16 @@ class NERExtractor(LLMBasedExtractor): property_name: str = "entities" prompt: NERPrompt = NERPrompt() + max_num_entities: int = 10 async def extract(self, node: Node) -> t.Tuple[str, t.List[str]]: node_text = node.get_property("page_content") if node_text is None: return self.property_name, [] - result = await self.prompt.generate(self.llm, data=StringIO(text=node_text)) + result = await self.prompt.generate( + self.llm, + data=TextWithExtractionLimit(text=node_text, max_num=self.max_num_entities), + ) return self.property_name, result.entities @@ -305,14 +327,17 @@ class ThemesAndConcepts(BaseModel): output: t.List[str] -class ThemesAndConceptsExtractorPrompt(PydanticPrompt[StringIO, ThemesAndConcepts]): +class ThemesAndConceptsExtractorPrompt( + PydanticPrompt[TextWithExtractionLimit, ThemesAndConcepts] +): instruction: str = "Extract the main themes and concepts from the given text." - input_model: t.Type[StringIO] = StringIO + input_model: t.Type[TextWithExtractionLimit] = TextWithExtractionLimit output_model: t.Type[ThemesAndConcepts] = ThemesAndConcepts - examples: t.List[t.Tuple[StringIO, ThemesAndConcepts]] = [ + examples: t.List[t.Tuple[TextWithExtractionLimit, ThemesAndConcepts]] = [ ( - StringIO( - text="Artificial intelligence is transforming industries by automating tasks requiring human intelligence. AI analyzes vast data quickly and accurately, driving innovations like self-driving cars and personalized recommendations." + TextWithExtractionLimit( + text="Artificial intelligence is transforming industries by automating tasks requiring human intelligence. 
AI analyzes vast data quickly and accurately, driving innovations like self-driving cars and personalized recommendations.", + max_num=10, ), ThemesAndConcepts( output=[ @@ -343,10 +368,14 @@ class ThemesExtractor(LLMBasedExtractor): property_name: str = "themes" prompt: ThemesAndConceptsExtractorPrompt = ThemesAndConceptsExtractorPrompt() + max_num_themes: int = 10 async def extract(self, node: Node) -> t.Tuple[str, t.List[str]]: node_text = node.get_property("page_content") if node_text is None: return self.property_name, [] - result = await self.prompt.generate(self.llm, data=StringIO(text=node_text)) + result = await self.prompt.generate( + self.llm, + data=TextWithExtractionLimit(text=node_text, max_num=self.max_num_themes), + ) return self.property_name, result.output diff --git a/src/ragas/testset/transforms/filters.py b/src/ragas/testset/transforms/filters.py new file mode 100644 index 000000000..44370fde4 --- /dev/null +++ b/src/ragas/testset/transforms/filters.py @@ -0,0 +1,81 @@ +import logging +import typing as t +from dataclasses import dataclass, field + +from pydantic import BaseModel, Field + +from ragas.prompt import PydanticPrompt +from ragas.testset.graph import KnowledgeGraph, Node +from ragas.testset.graph_queries import get_parent_nodes +from ragas.testset.transforms.base import LLMBasedNodeFilter + +logger = logging.getLogger(__name__) + + +DEFAULT_RUBRICS = { + "score1_description": "The page content is irrelevant or does not align with the main themes or topics of the document summary.", + "score2_description": "The page content partially aligns with the document summary, but it includes unrelated details or lacks critical information related to the document's main themes.", + "score3_description": "The page content generally reflects the document summary but may miss key details or lack depth in addressing the main themes.", + "score4_description": "The page content aligns well with the document summary, covering the main themes and topics with minor gaps or minimal unrelated information.", + "score5_description": "The page content is highly relevant, accurate, and directly reflects the main themes of the document summary, covering all important details and adding depth to the understanding of the document's topics.", +} + + +class QuestionPotentialInput(BaseModel): + document_summary: str = Field( + ..., + description="The summary of the document to provide context for evaluating the node.", + ) + node_content: str = Field( + ..., + description="The content of the node to evaluate for question generation potential.", + ) + rubrics: t.Dict[str, str] = Field(..., description="The rubric") + + +class QuestionPotentialOutput(BaseModel): + score: int = Field( + ..., + description="1 to 5 score", + ) + + +class QuestionPotentialPrompt( + PydanticPrompt[QuestionPotentialInput, QuestionPotentialOutput] +): + instruction = ( + "Given a document summary and node content, score the content of the node in 1 to 5 range." 
+ "" + ) + input_model = QuestionPotentialInput + output_model = QuestionPotentialOutput + + +@dataclass +class CustomNodeFilter(LLMBasedNodeFilter): + """ + returns True if the score is less than min_score + """ + + scoring_prompt: PydanticPrompt = field(default_factory=QuestionPotentialPrompt) + min_score: int = 2 + rubrics: t.Dict[str, str] = field(default_factory=lambda: DEFAULT_RUBRICS) + + async def custom_filter(self, node: Node, kg: KnowledgeGraph) -> bool: + + parent_nodes = get_parent_nodes(node, kg) + if len(parent_nodes) > 0: + summary = parent_nodes[0].properties.get("summary", "") + else: + summary = "" + + if summary == "": + logger.warning(f"Node {node} has no parent node with a summary.") + + prompt_input = QuestionPotentialInput( + document_summary=summary, + node_content=node.properties.get("page_content", ""), + rubrics=self.rubrics, + ) + response = await self.scoring_prompt.generate(data=prompt_input, llm=self.llm) + return response.score <= self.min_score diff --git a/src/ragas/testset/transforms/splitters/headline.py b/src/ragas/testset/transforms/splitters/headline.py index 3dae5e763..2f0a58910 100644 --- a/src/ragas/testset/transforms/splitters/headline.py +++ b/src/ragas/testset/transforms/splitters/headline.py @@ -3,12 +3,45 @@ from ragas.testset.graph import Node, NodeType, Relationship from ragas.testset.transforms.base import Splitter -from ragas.utils import num_tokens_from_string @dataclass class HeadlineSplitter(Splitter): min_tokens: int = 300 + max_tokens: int = 1000 + + def adjust_chunks(self, chunks): + adjusted_chunks = [] + current_chunk = "" + + for chunk in chunks: + chunk_tokens = chunk.split() + + # Split chunks that are over max_tokens + while len(chunk_tokens) > self.max_tokens: + adjusted_chunks.append(" ".join(chunk_tokens[: self.max_tokens])) + chunk_tokens = chunk_tokens[self.max_tokens :] + + # Handle chunks that are under min_tokens + if len(chunk_tokens) < self.min_tokens: + if current_chunk: + current_chunk += " " + " ".join(chunk_tokens) + if len(current_chunk.split()) >= self.min_tokens: + adjusted_chunks.append(current_chunk) + current_chunk = "" + else: + current_chunk = " ".join(chunk_tokens) + else: + if current_chunk: + adjusted_chunks.append(current_chunk) + current_chunk = "" + adjusted_chunks.append(" ".join(chunk_tokens)) + + # Append any remaining chunk + if current_chunk: + adjusted_chunks.append(current_chunk) + + return adjusted_chunks async def split(self, node: Node) -> t.Tuple[t.List[Node], t.List[Relationship]]: text = node.get_property("page_content") @@ -27,19 +60,7 @@ async def split(self, node: Node) -> t.Tuple[t.List[Node], t.List[Relationship]] indices.append(index) indices.append(len(text)) chunks = [text[indices[i] : indices[i + 1]] for i in range(len(indices) - 1)] - # merge chunks if their length is less than 300 tokens - merged_chunks = [] - current_chunk = chunks[0] - - for next_chunk in chunks[1:]: - if num_tokens_from_string(current_chunk) < self.min_tokens: - current_chunk = "\n\n".join([current_chunk, next_chunk]) - else: - merged_chunks.append(current_chunk) - current_chunk = next_chunk - - merged_chunks.append(current_chunk) - chunks = merged_chunks + chunks = self.adjust_chunks(chunks) # if there was no headline, return the original node if len(chunks) == 1: