Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improvements in test gen #1645

Merged
merged 6 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions src/ragas/testset/graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import typing as t
import uuid
from copy import deepcopy
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
Expand Down Expand Up @@ -268,6 +269,53 @@ def dfs(node: Node, cluster: t.Set[Node], depth: int, path: t.Tuple[Node, ...]):

return unique_clusters

def remove_node(
    self, node: "Node", inplace: bool = True
) -> t.Optional["KnowledgeGraph"]:
    """
    Delete ``node`` and every relationship that touches it.

    Parameters
    ----------
    node : Node
        The node to delete.
    inplace : bool, optional
        When True (default), mutate this graph and return None.
        When False, leave this graph untouched and return a pruned deep copy.

    Returns
    -------
    KnowledgeGraph or None
        The pruned copy when ``inplace`` is False, otherwise None.

    Raises
    ------
    ValueError
        If ``node`` is not part of this graph.
    """
    if node not in self.nodes:
        raise ValueError("Node is not present in the knowledge graph.")

    # Work either on this instance or on a detached deep copy, then prune
    # the node and any relationship whose endpoint equals it.
    graph = self if inplace else deepcopy(self)
    graph.nodes.remove(node)
    graph.relationships = [
        rel
        for rel in graph.relationships
        if rel.source != node and rel.target != node
    ]
    return None if inplace else graph

def find_direct_clusters(
self, relationship_condition: t.Callable[[Relationship], bool] = lambda _: True
) -> t.Dict[Node, t.List[t.Set[Node]]]:
Expand Down
35 changes: 35 additions & 0 deletions src/ragas/testset/graph_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,38 @@ def dfs(current_node: Node, current_level: int):
dfs(node, 1)

return children


def get_parent_nodes(
    node: "Node", graph: "KnowledgeGraph", level: int = 1
) -> t.List["Node"]:
    """
    Collect the ancestors of ``node`` reachable through "child" relationships,
    up to ``level`` hops away.

    Parameters
    ----------
    node : Node
        The node whose parents are wanted.
    graph : KnowledgeGraph
        The knowledge graph whose relationships are searched.
    level : int
        How many hops upward to follow (1 = direct parents only).

    Returns
    -------
    List[Node]
        The parents discovered, in depth-first discovery order.
    """
    found: t.List["Node"] = []

    def climb(child: "Node", depth: int) -> None:
        # Stop once we have walked past the requested number of hops.
        if depth > level:
            return
        for edge in graph.relationships:
            # A "child" relationship points parent -> child, so the parent
            # of ``child`` is that edge's source.
            if edge.type == "child" and edge.target == child:
                found.append(edge.source)
                climb(edge.source, depth + 1)

    # The first hop away from ``node`` counts as depth 1.
    climb(node, 1)

    return found
5 changes: 4 additions & 1 deletion src/ragas/testset/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .base import BaseGraphTransformation, Extractor, RelationshipBuilder, Splitter
from .base import BaseGraphTransformation, Extractor, RelationshipBuilder, Splitter, NodeFilter
from .default import default_transforms
from .engine import Parallel, Transforms, apply_transforms, rollback_transforms
from .extractors import (
Expand All @@ -13,6 +13,7 @@
SummaryCosineSimilarityBuilder,
)
from .splitters import HeadlineSplitter
from .filters import CustomNodeFilter

__all__ = [
# base
Expand All @@ -37,4 +38,6 @@
"SummaryCosineSimilarityBuilder",
# splitters
"HeadlineSplitter",
"CustomNodeFilter",
"NodeFilter",
]
52 changes: 52 additions & 0 deletions src/ragas/testset/transforms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,3 +322,55 @@ async def apply_build_relationships(

filtered_kg = self.filter(kg)
return [apply_build_relationships(filtered_kg=filtered_kg, original_kg=kg)]


@dataclass
class NodeFilter(BaseGraphTransformation):
    """
    Graph transformation that removes nodes flagged by ``custom_filter``.

    Subclasses implement :meth:`custom_filter`. ``transform`` returns a new
    graph with every flagged node (and its relationships) removed, while
    ``generate_execution_plan`` removes flagged nodes from ``kg`` in place.
    """

    async def transform(self, kg: KnowledgeGraph) -> KnowledgeGraph:
        """
        Return a copy of ``kg`` with all nodes flagged by ``custom_filter`` removed.

        Fix: the previous implementation returned inside the loop, so only the
        FIRST flagged node was ever removed; this now removes every flagged
        node, matching ``generate_execution_plan``. The input graph is never
        mutated: each removal uses ``remove_node(..., inplace=False)``, which
        returns a pruned deep copy.
        """
        filtered = self.filter(kg)

        result = kg
        for node in filtered.nodes:
            if await self.custom_filter(node, kg):
                # NOTE: relies on value-based Node equality so the flagged
                # node is still found inside the deep-copied graph.
                pruned = result.remove_node(node, inplace=False)
                if not isinstance(pruned, KnowledgeGraph):
                    raise ValueError("Error in removing node")
                result = pruned
        return result

    @abstractmethod
    async def custom_filter(self, node: Node, kg: KnowledgeGraph) -> bool:
        """
        Decide whether ``node`` should be removed from the graph.

        Parameters
        ----------
        node : Node
            The node to be filtered.
        kg : KnowledgeGraph
            The graph the node belongs to, for contextual decisions.

        Returns
        -------
        bool
            True when the node should be removed.
        """
        pass

    def generate_execution_plan(self, kg: KnowledgeGraph) -> t.List[t.Coroutine]:
        """
        Generates a list of coroutines to be executed; each coroutine removes
        its node from ``kg`` in place when ``custom_filter`` flags it.
        """

        async def apply_filter(node: Node):
            if await self.custom_filter(node, kg):
                kg.remove_node(node)

        filtered = self.filter(kg)
        return [apply_filter(node) for node in filtered.nodes]


@dataclass
class LLMBasedNodeFilter(NodeFilter, PromptMixin):
    """Node filter whose keep/drop decision is intended to be made by an LLM.

    Combines :class:`NodeFilter` with :class:`PromptMixin`; subclasses are
    expected to implement ``custom_filter`` using ``self.llm`` and prompts.
    """

    # LLM instance used for the filtering decision; defaults to llm_factory().
    llm: BaseRagasLLM = field(default_factory=llm_factory)
8 changes: 6 additions & 2 deletions src/ragas/testset/transforms/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
SummaryExtractor,
)
from ragas.testset.transforms.extractors.llm_based import NERExtractor, ThemesExtractor
from ragas.testset.transforms.filters import CustomNodeFilter
from ragas.testset.transforms.relationship_builders import (
CosineSimilarityBuilder,
OverlapScoreBuilder,
Expand Down Expand Up @@ -82,11 +83,14 @@ def summary_filter(node):
threshold=0.01, filter_nodes=lambda node: node.type == NodeType.CHUNK
)

node_filter = CustomNodeFilter(llm=llm, filter_nodes=lambda node: node.type == NodeType.CHUNK)

transforms = [
headline_extractor,
splitter,
Parallel(summary_extractor, theme_extractor, ner_extractor),
summary_emb_extractor,
summary_extractor,
node_filter,
Parallel(summary_emb_extractor, theme_extractor, ner_extractor),
Parallel(cosine_sim_builder, ner_overlap_sim),
]

Expand Down
85 changes: 57 additions & 28 deletions src/ragas/testset/transforms/extractors/llm_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,35 +70,44 @@ class Headlines(BaseModel):


class HeadlinesExtractorPrompt(PydanticPrompt[StringIO, Headlines]):
instruction: str = "Extract only level 2 headings from the given text."
instruction: str = "Extract only level 2 and level 3 headings from the given text."

input_model: t.Type[StringIO] = StringIO
output_model: t.Type[Headlines] = Headlines
examples: t.List[t.Tuple[StringIO, Headlines]] = [
(
StringIO(
text="""\
Introduction
Overview of the topic...
Introduction
Overview of the topic...

Main Concepts
Explanation of core ideas...
Main Concepts
Explanation of core ideas...

Detailed Analysis
Techniques and methods for analysis...
Detailed Analysis
Techniques and methods for analysis...

Subsection: Specialized Techniques
Further details on specialized techniques...
Subsection: Specialized Techniques
Further details on specialized techniques...

Future Directions
Insights into upcoming trends...
Future Directions
Insights into upcoming trends...

Conclusion
Final remarks and summary.
""",
Subsection: Next Steps in Research
Discussion of new areas of study...

Conclusion
Final remarks and summary.
"""
),
Headlines(
headlines=["Main Concepts", "Detailed Analysis", "Future Directions"]
headlines=[
"Main Concepts",
"Detailed Analysis",
"Subsection: Specialized Techniques",
"Future Directions",
"Subsection: Next Steps in Research",
]
),
),
]
Expand All @@ -108,15 +117,24 @@ class NEROutput(BaseModel):
entities: t.List[str]


class NERPrompt(PydanticPrompt[StringIO, NEROutput]):
instruction: str = "Extract named entities from the given text."
input_model: t.Type[StringIO] = StringIO
class TextWithExtractionLimit(BaseModel):
    """Prompt input pairing a text with a cap on how many items to extract."""

    # Raw text the extraction prompt operates on.
    text: str
    # Maximum number of items (entities/themes) the prompt should return.
    max_num: int = 10


class NERPrompt(PydanticPrompt[TextWithExtractionLimit, NEROutput]):
instruction: str = (
"Extract the named entities from the given text, limiting the output to the top entities. "
"Ensure the number of entities does not exceed the specified maximum."
)
input_model: t.Type[TextWithExtractionLimit] = TextWithExtractionLimit
output_model: t.Type[NEROutput] = NEROutput
examples: t.List[t.Tuple[StringIO, NEROutput]] = [
examples: t.List[t.Tuple[TextWithExtractionLimit, NEROutput]] = [
(
StringIO(
TextWithExtractionLimit(
text="""Elon Musk, the CEO of Tesla and SpaceX, announced plans to expand operations to new locations in Europe and Asia.
This expansion is expected to create thousands of jobs, particularly in cities like Berlin and Shanghai."""
This expansion is expected to create thousands of jobs, particularly in cities like Berlin and Shanghai.""",
max_num=10,
),
NEROutput(
entities=[
Expand Down Expand Up @@ -246,12 +264,16 @@ class NERExtractor(LLMBasedExtractor):

property_name: str = "entities"
prompt: NERPrompt = NERPrompt()
max_num_entities: int = 10

async def extract(self, node: Node) -> t.Tuple[str, t.List[str]]:
node_text = node.get_property("page_content")
if node_text is None:
return self.property_name, []
result = await self.prompt.generate(self.llm, data=StringIO(text=node_text))
result = await self.prompt.generate(
self.llm,
data=TextWithExtractionLimit(text=node_text, max_num=self.max_num_entities),
)
return self.property_name, result.entities


Expand Down Expand Up @@ -305,14 +327,17 @@ class ThemesAndConcepts(BaseModel):
output: t.List[str]


class ThemesAndConceptsExtractorPrompt(PydanticPrompt[StringIO, ThemesAndConcepts]):
class ThemesAndConceptsExtractorPrompt(
PydanticPrompt[TextWithExtractionLimit, ThemesAndConcepts]
):
instruction: str = "Extract the main themes and concepts from the given text."
input_model: t.Type[StringIO] = StringIO
input_model: t.Type[TextWithExtractionLimit] = TextWithExtractionLimit
output_model: t.Type[ThemesAndConcepts] = ThemesAndConcepts
examples: t.List[t.Tuple[StringIO, ThemesAndConcepts]] = [
examples: t.List[t.Tuple[TextWithExtractionLimit, ThemesAndConcepts]] = [
(
StringIO(
text="Artificial intelligence is transforming industries by automating tasks requiring human intelligence. AI analyzes vast data quickly and accurately, driving innovations like self-driving cars and personalized recommendations."
TextWithExtractionLimit(
text="Artificial intelligence is transforming industries by automating tasks requiring human intelligence. AI analyzes vast data quickly and accurately, driving innovations like self-driving cars and personalized recommendations.",
max_num=10,
),
ThemesAndConcepts(
output=[
Expand Down Expand Up @@ -343,10 +368,14 @@ class ThemesExtractor(LLMBasedExtractor):

property_name: str = "themes"
prompt: ThemesAndConceptsExtractorPrompt = ThemesAndConceptsExtractorPrompt()
max_num_themes: int = 10

async def extract(self, node: Node) -> t.Tuple[str, t.List[str]]:
node_text = node.get_property("page_content")
if node_text is None:
return self.property_name, []
result = await self.prompt.generate(self.llm, data=StringIO(text=node_text))
result = await self.prompt.generate(
self.llm,
data=TextWithExtractionLimit(text=node_text, max_num=self.max_num_themes),
)
return self.property_name, result.output
Loading
Loading